Source code for harl.runners.off_policy_ma_runner

"""Runner for off-policy MA algorithms"""
import copy
import torch
from harl.runners.off_policy_base_runner import OffPolicyBaseRunner


class OffPolicyMARunner(OffPolicyBaseRunner):
    """Runner for off-policy MA algorithms."""

    def train(self):
        """Train the model"""
        self.total_it += 1
        data = self.buffer.sample()
        (
            sp_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_obs,  # (n_agents, batch_size, dim)
            sp_actions,  # (n_agents, batch_size, dim)
            sp_available_actions,  # (n_agents, batch_size, dim)
            sp_reward,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_done,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_valid_transition,  # (n_agents, batch_size, 1)
            sp_term,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_next_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_next_obs,  # (n_agents, batch_size, dim)
            sp_next_available_actions,  # (n_agents, batch_size, dim)
            sp_gamma,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
        ) = data
        # train critic
        self.critic.turn_on_grad()
        next_actions = []
        for agent_id in range(self.num_agents):
            next_actions.append(
                self.actor[agent_id].get_target_actions(sp_next_obs[agent_id])
            )
        self.critic.train(
            sp_share_obs,
            sp_actions,
            sp_reward,
            sp_done,
            sp_term,
            sp_next_share_obs,
            next_actions,
            sp_gamma,
        )
        self.critic.turn_off_grad()
        if self.total_it % self.policy_freq == 0:
            # train actors
            # actions shape: (n_agents, batch_size, dim)
            for agent_id in range(self.num_agents):
                actions = copy.deepcopy(torch.tensor(sp_actions)).to(self.device)
                self.actor[agent_id].turn_on_grad()
                # train this agent
                actions[agent_id] = self.actor[agent_id].get_actions(
                    sp_obs[agent_id], False
                )
                actions_list = [a for a in actions]
                actions_t = torch.cat(actions_list, dim=-1)
                value_pred = self.critic.get_values(sp_share_obs, actions_t)
                actor_loss = -torch.mean(value_pred)
                self.actor[agent_id].actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor[agent_id].actor_optimizer.step()
                self.actor[agent_id].turn_off_grad()
            # soft update
            for agent_id in range(self.num_agents):
                self.actor[agent_id].soft_update()
            self.critic.soft_update()
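
The actor loop above follows the MADDPG-style pattern: for each agent, take the sampled joint action, substitute that agent's current-policy action, and ascend the centralized critic's value. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. All classes, names, and shapes here (toy linear actors/critic, dimensions) are hypothetical stand-ins for illustration only, not HARL's actual actor/critic wrappers.

import torch
import torch.nn as nn

n_agents, batch_size, obs_dim, act_dim = 3, 32, 8, 2

# toy decentralized actors and a centralized critic (hypothetical stand-ins)
actors = [nn.Linear(obs_dim, act_dim) for _ in range(n_agents)]
critic = nn.Linear(obs_dim * n_agents + act_dim * n_agents, 1)
actor_optims = [torch.optim.Adam(a.parameters(), lr=1e-3) for a in actors]

share_obs = torch.randn(batch_size, obs_dim * n_agents)        # centralized observation
obs = torch.randn(n_agents, batch_size, obs_dim)               # per-agent observations
sampled_actions = torch.randn(n_agents, batch_size, act_dim)   # joint actions from the buffer

for agent_id in range(n_agents):
    # start from the sampled joint action, then substitute this agent's
    # current-policy action so only its parameters receive gradients
    joint = sampled_actions.clone()
    joint[agent_id] = actors[agent_id](obs[agent_id])
    joint_flat = torch.cat([joint[i] for i in range(n_agents)], dim=-1)

    # centralized value of (shared obs, joint action); maximize it for this agent
    value = critic(torch.cat([share_obs, joint_flat], dim=-1))
    actor_loss = -value.mean()

    actor_optims[agent_id].zero_grad()
    actor_loss.backward()
    actor_optims[agent_id].step()

In the real runner, target networks, the policy_freq delay, and soft updates (handled by the actor/critic wrappers) complete the TD3-style variant of this update.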