Source code for harl.algorithms.critics.twin_continuous_q_critic

"""Twin Continuous Q Critic."""
import itertools
from copy import deepcopy
import torch
from harl.models.value_function_models.continuous_q_net import ContinuousQNet
from harl.utils.envs_tools import check
from harl.utils.models_tools import update_linear_schedule


class TwinContinuousQCritic:
    """Twin Continuous Q Critic.

    Critic that learns two Q-functions. The name TwinContinuousQCritic reflects its
    structure: it takes observations and actions as input and outputs Q-values. This
    structure is commonly used for continuous action spaces, although it can in
    principle also handle discrete ones. For now only continuous action spaces are
    supported; discrete action spaces will be added in the future.
    """

    def __init__(
        self,
        args,
        share_obs_space,
        act_space,
        num_agents,
        state_type,
        device=torch.device("cpu"),
    ):
        """Initialize the critic."""
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.act_space = act_space
        self.num_agents = num_agents
        self.state_type = state_type
        self.action_type = act_space[0].__class__.__name__
        self.critic = ContinuousQNet(args, share_obs_space, act_space, device)
        self.critic2 = ContinuousQNet(args, share_obs_space, act_space, device)
        self.target_critic = deepcopy(self.critic)
        self.target_critic2 = deepcopy(self.critic2)
        # Target networks are only updated via polyak averaging, never by gradients.
        for param in self.target_critic.parameters():
            param.requires_grad = False
        for param in self.target_critic2.parameters():
            param.requires_grad = False
        self.gamma = args["gamma"]
        self.critic_lr = args["critic_lr"]
        self.polyak = args["polyak"]
        self.use_proper_time_limits = args["use_proper_time_limits"]
        # Both Q-networks are trained with a single shared optimizer.
        critic_params = itertools.chain(
            self.critic.parameters(), self.critic2.parameters()
        )
        self.critic_optimizer = torch.optim.Adam(
            critic_params,
            lr=self.critic_lr,
        )
        self.turn_off_grad()

    def lr_decay(self, step, steps):
        """Decay the critic learning rate linearly.

        Args:
            step: (int) current training step.
            steps: (int) total number of training steps.
        """
        update_linear_schedule(self.critic_optimizer, step, steps, self.critic_lr)

    def soft_update(self):
        """Soft update the target networks."""
        # Polyak averaging: theta_target <- (1 - polyak) * theta_target + polyak * theta.
        for param_target, param in zip(
            self.target_critic.parameters(), self.critic.parameters()
        ):
            param_target.data.copy_(
                param_target.data * (1.0 - self.polyak) + param.data * self.polyak
            )
        for param_target, param in zip(
            self.target_critic2.parameters(), self.critic2.parameters()
        ):
            param_target.data.copy_(
                param_target.data * (1.0 - self.polyak) + param.data * self.polyak
            )

    def get_values(self, share_obs, actions):
        """Get the Q values for the given observations and actions."""
        share_obs = check(share_obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)
        # Only the first Q-network is used for value estimation.
        return self.critic(share_obs, actions)

    def train(
        self,
        share_obs,
        actions,
        reward,
        done,
        term,
        next_share_obs,
        next_actions,
        gamma,
    ):
        """Train the critic.

        Args:
            share_obs: (np.ndarray) shape is (batch_size, dim)
            actions: (np.ndarray) shape is (n_agents, batch_size, dim)
            reward: (np.ndarray) shape is (batch_size, 1)
            done: (np.ndarray) shape is (batch_size, 1)
            term: (np.ndarray) shape is (batch_size, 1)
            next_share_obs: (np.ndarray) shape is (batch_size, dim)
            next_actions: (np.ndarray) shape is (n_agents, batch_size, dim)
            gamma: (np.ndarray) shape is (batch_size, 1)
        """
        assert share_obs.__class__.__name__ == "ndarray"
        assert actions.__class__.__name__ == "ndarray"
        assert reward.__class__.__name__ == "ndarray"
        assert done.__class__.__name__ == "ndarray"
        assert term.__class__.__name__ == "ndarray"
        assert next_share_obs.__class__.__name__ == "ndarray"
        assert gamma.__class__.__name__ == "ndarray"

        share_obs = check(share_obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)
        # Concatenate per-agent actions into a joint action along the last dimension.
        actions = torch.cat([actions[i] for i in range(actions.shape[0])], dim=-1)
        reward = check(reward).to(**self.tpdv)
        done = check(done).to(**self.tpdv)
        term = check(term).to(**self.tpdv)
        gamma = check(gamma).to(**self.tpdv)
        next_share_obs = check(next_share_obs).to(**self.tpdv)
        next_actions = torch.cat(next_actions, dim=-1).to(**self.tpdv)
        # Clipped double-Q target: bootstrap with the minimum of the two target critics.
        next_q_values1 = self.target_critic(next_share_obs, next_actions)
        next_q_values2 = self.target_critic2(next_share_obs, next_actions)
        next_q_values = torch.min(next_q_values1, next_q_values2)
        if self.use_proper_time_limits:
            q_targets = reward + gamma * next_q_values * (1 - term)
        else:
            q_targets = reward + gamma * next_q_values * (1 - done)
        critic_loss1 = torch.mean(
            torch.nn.functional.mse_loss(self.critic(share_obs, actions), q_targets)
        )
        critic_loss2 = torch.mean(
            torch.nn.functional.mse_loss(self.critic2(share_obs, actions), q_targets)
        )
        critic_loss = critic_loss1 + critic_loss2
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
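
    # The TD target formed in train() above is the clipped double-Q backup
    #     y = r + gamma * min(Q1_targ(s', a'), Q2_targ(s', a')) * (1 - d),
    # where d is `term` when use_proper_time_limits is set and `done` otherwise,
    # so transitions that presumably end only because of a time limit (done but
    # not terminated) can still bootstrap from the next state.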

    def save(self, save_dir):
        """Save the model parameters."""
        torch.save(self.critic.state_dict(), str(save_dir) + "/critic_agent" + ".pt")
        torch.save(
            self.target_critic.state_dict(),
            str(save_dir) + "/target_critic_agent" + ".pt",
        )
        torch.save(self.critic2.state_dict(), str(save_dir) + "/critic_agent2" + ".pt")
        torch.save(
            self.target_critic2.state_dict(),
            str(save_dir) + "/target_critic_agent2" + ".pt",
        )

    def restore(self, model_dir):
        """Restore the model parameters."""
        critic_state_dict = torch.load(str(model_dir) + "/critic_agent" + ".pt")
        self.critic.load_state_dict(critic_state_dict)
        target_critic_state_dict = torch.load(
            str(model_dir) + "/target_critic_agent" + ".pt"
        )
        self.target_critic.load_state_dict(target_critic_state_dict)
        critic_state_dict2 = torch.load(str(model_dir) + "/critic_agent2" + ".pt")
        self.critic2.load_state_dict(critic_state_dict2)
        target_critic_state_dict2 = torch.load(
            str(model_dir) + "/target_critic_agent2" + ".pt"
        )
        self.target_critic2.load_state_dict(target_critic_state_dict2)

    def turn_on_grad(self):
        """Turn on gradients for the critic networks."""
        for param in self.critic.parameters():
            param.requires_grad = True
        for param in self.critic2.parameters():
            param.requires_grad = True

    def turn_off_grad(self):
        """Turn off gradients for the critic networks."""
        for param in self.critic.parameters():
            param.requires_grad = False
        for param in self.critic2.parameters():
            param.requires_grad = False
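
To make the data flow concrete, here is a minimal, self-contained sketch of the clipped double-Q update that train and soft_update perform, with small nn.Linear critics standing in for ContinuousQNet; all dimensions, the learning rate, and the gamma/polyak values are illustrative assumptions, not values prescribed by HARL.

import torch
import torch.nn as nn
from copy import deepcopy

batch, obs_dim, joint_act_dim, gamma, polyak = 32, 10, 6, 0.99, 0.005

# Two online critics and their target copies (assumed stand-ins for ContinuousQNet).
q1 = nn.Linear(obs_dim + joint_act_dim, 1)
q2 = nn.Linear(obs_dim + joint_act_dim, 1)
q1_targ, q2_targ = deepcopy(q1), deepcopy(q2)
optimizer = torch.optim.Adam(list(q1.parameters()) + list(q2.parameters()), lr=5e-4)

# Dummy batch of transitions (shared observation, joint action, reward, done flag).
share_obs = torch.randn(batch, obs_dim)
joint_act = torch.randn(batch, joint_act_dim)
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1)
next_obs = torch.randn(batch, obs_dim)
next_act = torch.randn(batch, joint_act_dim)

# Clipped double-Q target: bootstrap with the minimum of the two target critics.
with torch.no_grad():
    next_inp = torch.cat([next_obs, next_act], dim=-1)
    next_q = torch.min(q1_targ(next_inp), q2_targ(next_inp))
    target = reward + gamma * next_q * (1 - done)

# Regress both online critics toward the shared target, as train() does.
inp = torch.cat([share_obs, joint_act], dim=-1)
loss = nn.functional.mse_loss(q1(inp), target) + nn.functional.mse_loss(q2(inp), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Polyak-average the targets, mirroring soft_update().
for targ, src in ((q1_targ, q1), (q2_targ, q2)):
    for p_t, p in zip(targ.parameters(), src.parameters()):
        p_t.data.copy_(p_t.data * (1 - polyak) + p.data * polyak)

Taking the minimum over the two target critics is the standard TD3-style trick for curbing Q-value overestimation, which is why the class keeps a second critic and target around at all.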