Source code for harl.runners.off_policy_ha_runner

"""Runner for off-policy HARL algorithms."""
import torch
import numpy as np
import torch.nn.functional as F
from harl.runners.off_policy_base_runner import OffPolicyBaseRunner


class OffPolicyHARunner(OffPolicyBaseRunner):
    """Runner for off-policy HA algorithms."""

    def train(self):
        """Train the model"""
        self.total_it += 1
        data = self.buffer.sample()
        (
            sp_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_obs,  # (n_agents, batch_size, dim)
            sp_actions,  # (n_agents, batch_size, dim)
            sp_available_actions,  # (n_agents, batch_size, dim)
            sp_reward,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_done,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_valid_transition,  # (n_agents, batch_size, 1)
            sp_term,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_next_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_next_obs,  # (n_agents, batch_size, dim)
            sp_next_available_actions,  # (n_agents, batch_size, dim)
            sp_gamma,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
        ) = data
        # train critic
        self.critic.turn_on_grad()
        if self.args["algo"] == "hasac":
            next_actions = []
            next_logp_actions = []
            for agent_id in range(self.num_agents):
                next_action, next_logp_action = self.actor[
                    agent_id
                ].get_actions_with_logprobs(
                    sp_next_obs[agent_id],
                    sp_next_available_actions[agent_id]
                    if sp_next_available_actions is not None
                    else None,
                )
                next_actions.append(next_action)
                next_logp_actions.append(next_logp_action)
            self.critic.train(
                sp_share_obs,
                sp_actions,
                sp_reward,
                sp_done,
                sp_valid_transition,
                sp_term,
                sp_next_share_obs,
                next_actions,
                next_logp_actions,
                sp_gamma,
                self.value_normalizer,
            )
        else:
            next_actions = []
            for agent_id in range(self.num_agents):
                next_actions.append(
                    self.actor[agent_id].get_target_actions(sp_next_obs[agent_id])
                )
            self.critic.train(
                sp_share_obs,
                sp_actions,
                sp_reward,
                sp_done,
                sp_term,
                sp_next_share_obs,
                next_actions,
                sp_gamma,
            )
        self.critic.turn_off_grad()
        sp_valid_transition = torch.tensor(sp_valid_transition, device=self.device)
        if self.total_it % self.policy_freq == 0:
            # train actors
            if self.args["algo"] == "hasac":
                actions = []
                logp_actions = []
                with torch.no_grad():
                    for agent_id in range(self.num_agents):
                        action, logp_action = self.actor[
                            agent_id
                        ].get_actions_with_logprobs(
                            sp_obs[agent_id],
                            sp_available_actions[agent_id]
                            if sp_available_actions is not None
                            else None,
                        )
                        actions.append(action)
                        logp_actions.append(logp_action)
                # actions shape: (n_agents, batch_size, dim)
                # logp_actions shape: (n_agents, batch_size, 1)
                if self.fixed_order:
                    agent_order = list(range(self.num_agents))
                else:
                    agent_order = list(np.random.permutation(self.num_agents))
                for agent_id in agent_order:
                    self.actor[agent_id].turn_on_grad()
                    # train this agent
                    actions[agent_id], logp_actions[agent_id] = self.actor[
                        agent_id
                    ].get_actions_with_logprobs(
                        sp_obs[agent_id],
                        sp_available_actions[agent_id]
                        if sp_available_actions is not None
                        else None,
                    )
                    if self.state_type == "EP":
                        logp_action = logp_actions[agent_id]
                        actions_t = torch.cat(actions, dim=-1)
                    elif self.state_type == "FP":
                        logp_action = torch.tile(
                            logp_actions[agent_id], (self.num_agents, 1)
                        )
                        actions_t = torch.tile(
                            torch.cat(actions, dim=-1), (self.num_agents, 1)
                        )
                    value_pred = self.critic.get_values(sp_share_obs, actions_t)
                    if self.algo_args["algo"]["use_policy_active_masks"]:
                        if self.state_type == "EP":
                            actor_loss = (
                                -torch.sum(
                                    (value_pred - self.alpha[agent_id] * logp_action)
                                    * sp_valid_transition[agent_id]
                                )
                                / sp_valid_transition[agent_id].sum()
                            )
                        elif self.state_type == "FP":
                            valid_transition = torch.tile(
                                sp_valid_transition[agent_id], (self.num_agents, 1)
                            )
                            actor_loss = (
                                -torch.sum(
                                    (value_pred - self.alpha[agent_id] * logp_action)
                                    * valid_transition
                                )
                                / valid_transition.sum()
                            )
                    else:
                        actor_loss = -torch.mean(
                            value_pred - self.alpha[agent_id] * logp_action
                        )
                    self.actor[agent_id].actor_optimizer.zero_grad()
                    actor_loss.backward()
                    self.actor[agent_id].actor_optimizer.step()
                    self.actor[agent_id].turn_off_grad()
                    # train this agent's alpha
                    if self.algo_args["algo"]["auto_alpha"]:
                        log_prob = (
                            logp_actions[agent_id].detach()
                            + self.target_entropy[agent_id]
                        )
                        alpha_loss = -(self.log_alpha[agent_id] * log_prob).mean()
                        self.alpha_optimizer[agent_id].zero_grad()
                        alpha_loss.backward()
                        self.alpha_optimizer[agent_id].step()
                        self.alpha[agent_id] = torch.exp(
                            self.log_alpha[agent_id].detach()
                        )
                    actions[agent_id], _ = self.actor[
                        agent_id
                    ].get_actions_with_logprobs(
                        sp_obs[agent_id],
                        sp_available_actions[agent_id]
                        if sp_available_actions is not None
                        else None,
                    )
                # train critic's alpha
                if self.algo_args["algo"]["auto_alpha"]:
                    self.critic.update_alpha(logp_actions, np.sum(self.target_entropy))
            else:
                if self.args["algo"] == "had3qn":
                    actions = []
                    with torch.no_grad():
                        for agent_id in range(self.num_agents):
                            actions.append(
                                self.actor[agent_id].get_actions(
                                    sp_obs[agent_id], False
                                )
                            )
                    # actions shape: (n_agents, batch_size, 1)
                    update_actions, get_values = self.critic.train_values(
                        sp_share_obs, actions
                    )
                    if self.fixed_order:
                        agent_order = list(range(self.num_agents))
                    else:
                        agent_order = list(np.random.permutation(self.num_agents))
                    for agent_id in agent_order:
                        self.actor[agent_id].turn_on_grad()
                        # actor preds
                        actor_values = self.actor[agent_id].train_values(
                            sp_obs[agent_id], actions[agent_id]
                        )
                        # critic preds
                        critic_values = get_values()
                        # update
                        actor_loss = torch.mean(F.mse_loss(actor_values, critic_values))
                        self.actor[agent_id].actor_optimizer.zero_grad()
                        actor_loss.backward()
                        self.actor[agent_id].actor_optimizer.step()
                        self.actor[agent_id].turn_off_grad()
                        update_actions(agent_id)
                else:
                    actions = []
                    with torch.no_grad():
                        for agent_id in range(self.num_agents):
                            actions.append(
                                self.actor[agent_id].get_actions(
                                    sp_obs[agent_id], False
                                )
                            )
                    # actions shape: (n_agents, batch_size, dim)
                    if self.fixed_order:
                        agent_order = list(range(self.num_agents))
                    else:
                        agent_order = list(np.random.permutation(self.num_agents))
                    for agent_id in agent_order:
                        self.actor[agent_id].turn_on_grad()
                        # train this agent
                        actions[agent_id] = self.actor[agent_id].get_actions(
                            sp_obs[agent_id], False
                        )
                        actions_t = torch.cat(actions, dim=-1)
                        value_pred = self.critic.get_values(sp_share_obs, actions_t)
                        actor_loss = -torch.mean(value_pred)
                        self.actor[agent_id].actor_optimizer.zero_grad()
                        actor_loss.backward()
                        self.actor[agent_id].actor_optimizer.step()
                        self.actor[agent_id].turn_off_grad()
                        actions[agent_id] = self.actor[agent_id].get_actions(
                            sp_obs[agent_id], False
                        )
            # soft update
            for agent_id in range(self.num_agents):
                self.actor[agent_id].soft_update()
            self.critic.soft_update()
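
The HASAC branch of train() above performs a sequential maximum-entropy actor update: agents are visited in a fixed or randomly permuted order, each agent re-samples its own action with gradients enabled while the other agents' freshly sampled actions are held fixed, the loss -E[Q(s, a) - alpha_i * log pi_i(a_i | o_i)] is minimized, and the stored action is then refreshed so later agents in the order condition on the updated policy. Below is a minimal sketch of that pattern for the "EP" state type, with available-action masks, active masks, and automatic alpha tuning omitted; it reuses the actor/critic method names called above and is illustrative only, not part of the HARL source.

import numpy as np
import torch


def sequential_actor_update(actors, critic, obs, share_obs, alpha, fixed_order=False):
    """Sketch of the HASAC per-agent update loop (EP state type, no masks)."""
    num_agents = len(actors)
    # sample every agent's current action and log-prob without gradients
    with torch.no_grad():
        sampled = [
            actor.get_actions_with_logprobs(obs[i], None) for i, actor in enumerate(actors)
        ]
    actions = [action for action, _ in sampled]
    logp_actions = [logp for _, logp in sampled]
    order = range(num_agents) if fixed_order else np.random.permutation(num_agents)
    for agent_id in order:
        actors[agent_id].turn_on_grad()
        # re-sample this agent's action with gradients enabled
        actions[agent_id], logp_actions[agent_id] = actors[agent_id].get_actions_with_logprobs(
            obs[agent_id], None
        )
        value_pred = critic.get_values(share_obs, torch.cat(actions, dim=-1))
        # maximum-entropy objective: raise Q while penalizing low-entropy policies
        actor_loss = -torch.mean(value_pred - alpha[agent_id] * logp_actions[agent_id])
        actors[agent_id].actor_optimizer.zero_grad()
        actor_loss.backward()
        actors[agent_id].actor_optimizer.step()
        actors[agent_id].turn_off_grad()
        # refresh the stored action so later agents see the updated policy
        actions[agent_id], _ = actors[agent_id].get_actions_with_logprobs(obs[agent_id], None)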
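
The state_type branches above differ only in batch alignment: under "EP" the critic input holds one shared observation per transition (batch_size rows), while under "FP" it holds n_agents rows per transition (n_agents * batch_size in total), so the per-agent log-probs and the concatenated joint action are tiled along the batch dimension with torch.tile before being combined with value_pred. A small shape check with hypothetical sizes:

import torch

n_agents, batch_size, act_dim = 3, 4, 2
logp_action = torch.randn(batch_size, 1)                     # one agent's log-probs
joint_action = torch.randn(batch_size, n_agents * act_dim)   # concatenated actions

# FP: the critic batch is n_agents * batch_size, so tile along the batch axis
logp_fp = torch.tile(logp_action, (n_agents, 1))       # -> (12, 1)
actions_fp = torch.tile(joint_action, (n_agents, 1))   # -> (12, 6)
print(logp_fp.shape, actions_fp.shape)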