"""Continuous Q Critic."""
from copy import deepcopy
import torch
from harl.models.value_function_models.continuous_q_net import ContinuousQNet
from harl.utils.envs_tools import check
from harl.utils.models_tools import update_linear_schedule

class ContinuousQCritic:
    """Continuous Q Critic.

    Critic that learns a Q-function Q(share_obs, joint_action), where the
    joint action is the concatenation of all agents' actions along the last
    dimension. The name ContinuousQCritic emphasizes its structure: it takes
    observations and actions as input and outputs Q values. That structure is
    a natural fit for continuous action spaces. In principle it could also
    serve discrete action spaces, but for now only continuous action spaces
    are supported.

    A frozen target copy of the critic is used to compute TD targets and is
    moved towards the online critic by Polyak averaging (see ``soft_update``).
    """

    def __init__(
        self,
        args,
        share_obs_space,
        act_space,
        num_agents,
        state_type,
        device=torch.device("cpu"),
    ):
        """Initialize the critic.

        Args:
            args: (dict) hyperparameters. This class reads ``gamma``,
                ``critic_lr``, ``polyak`` and ``use_proper_time_limits``. The
                whole dict is also forwarded to ``ContinuousQNet``.
            share_obs_space: observation space forwarded to ``ContinuousQNet``.
            act_space: action space(s) forwarded to ``ContinuousQNet``.
            num_agents: number of agents (stored; not used inside this class).
            state_type: stored as-is; not used inside this class.
            device: (torch.device) device of the networks and input tensors.
        """
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.act_space = act_space
        self.num_agents = num_agents
        self.state_type = state_type
        self.critic = ContinuousQNet(args, share_obs_space, act_space, device)
        # The target network is never optimized directly; it only tracks the
        # online critic through soft_update().
        self.target_critic = deepcopy(self.critic)
        for p in self.target_critic.parameters():
            p.requires_grad = False
        self.gamma = args["gamma"]
        self.critic_lr = args["critic_lr"]
        self.polyak = args["polyak"]
        self.use_proper_time_limits = args["use_proper_time_limits"]
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.critic_lr
        )
        # Critic gradients start disabled; turn_on_grad() must be called
        # before a critic update (presumably by the owning algorithm).
        self.turn_off_grad()

    def lr_decay(self, step, steps):
        """Decay the critic learning rate via ``update_linear_schedule``.

        Only the critic optimizer is affected; this class owns no actor.

        Args:
            step: (int) current training step.
            steps: (int) total number of training steps.
        """
        update_linear_schedule(self.critic_optimizer, step, steps, self.critic_lr)

    def soft_update(self):
        """Polyak-average the online critic into the target critic.

        For every parameter pair:
        ``target <- target * (1 - polyak) + online * polyak``.
        """
        with torch.no_grad():
            for param_target, param in zip(
                self.target_critic.parameters(), self.critic.parameters()
            ):
                param_target.mul_(1.0 - self.polyak).add_(param, alpha=self.polyak)

    def get_values(self, share_obs, actions):
        """Get the Q values of the online critic.

        Args:
            share_obs: observations, converted with ``check`` and cast to
                ``self.tpdv``.
            actions: actions, converted with ``check`` and cast to
                ``self.tpdv``.

        Returns:
            The output of ``self.critic(share_obs, actions)``.
        """
        share_obs = check(share_obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)
        return self.critic(share_obs, actions)

    def _to_joint_action(self, agent_actions):
        """Concatenate per-agent actions along the last dim into one tensor.

        ``agent_actions`` is iterated over its first axis (agents). It can be
        an np.ndarray of shape (n_agents, batch_size, dim) or a sequence of
        per-agent tensors/arrays. ndarray items are converted with ``check``.
        The result is cast with ``self.tpdv``.
        """
        per_agent = [
            check(a) if a.__class__.__name__ == "ndarray" else a
            for a in agent_actions
        ]
        return torch.cat(per_agent, dim=-1).to(**self.tpdv)

    def train(
        self,
        share_obs,
        actions,
        reward,
        done,
        term,
        next_share_obs,
        next_actions,
        gamma,
    ):
        """Train the critic with one TD regression step on a batch.

        The regression target is
        ``reward + gamma * Q_target(next_share_obs, next_actions) * (1 - m)``.
        Here ``m`` is ``term`` when ``use_proper_time_limits`` is set and
        ``done`` otherwise. Presumably ``term`` flags only true terminations,
        so time-limit truncations still bootstrap.

        Args:
            share_obs: (np.ndarray) shape is (batch_size, dim)
            actions: (np.ndarray) shape is (n_agents, batch_size, dim)
            reward: (np.ndarray) shape is (batch_size, 1)
            done: (np.ndarray) shape is (batch_size, 1)
            term: (np.ndarray) shape is (batch_size, 1)
            next_share_obs: (np.ndarray) shape is (batch_size, dim)
            next_actions: per-agent next actions. Either a sequence of
                tensors/arrays each of shape (batch_size, dim), or an
                np.ndarray of shape (n_agents, batch_size, dim).
            gamma: (np.ndarray) shape is (batch_size, 1)

        Raises:
            TypeError: if any argument other than ``next_actions`` is not an
                np.ndarray.
        """
        # Explicit checks (unlike ``assert``) are not stripped under ``-O``.
        required_arrays = {
            "share_obs": share_obs,
            "actions": actions,
            "reward": reward,
            "done": done,
            "term": term,
            "next_share_obs": next_share_obs,
            "gamma": gamma,
        }
        for arg_name, value in required_arrays.items():
            if value.__class__.__name__ != "ndarray":
                raise TypeError(
                    f"{arg_name} must be an np.ndarray, got {type(value).__name__}"
                )

        share_obs = check(share_obs).to(**self.tpdv)
        actions = self._to_joint_action(actions)
        reward = check(reward).to(**self.tpdv)
        done = check(done).to(**self.tpdv)
        term = check(term).to(**self.tpdv)
        next_share_obs = check(next_share_obs).to(**self.tpdv)
        next_actions = self._to_joint_action(next_actions)
        gamma = check(gamma).to(**self.tpdv)

        # The TD target is a regression constant. Computing it without
        # autograd keeps gradients from flowing into whatever produced
        # next_actions, and does not change the critic's gradient.
        with torch.no_grad():
            next_q_values = self.target_critic(next_share_obs, next_actions)
            episode_end = term if self.use_proper_time_limits else done
            q_targets = reward + gamma * next_q_values * (1 - episode_end)

        # mse_loss already reduces to the scalar mean over all elements.
        critic_loss = torch.nn.functional.mse_loss(
            self.critic(share_obs, actions), q_targets
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def save(self, save_dir):
        """Save the online and target critic state dicts into ``save_dir``.

        Writes ``critic_agent.pt`` and ``target_critic_agent.pt``.
        """
        torch.save(self.critic.state_dict(), str(save_dir) + "/critic_agent.pt")
        torch.save(
            self.target_critic.state_dict(),
            str(save_dir) + "/target_critic_agent.pt",
        )

    def restore(self, model_dir):
        """Restore the online and target critics from files written by ``save``.

        Tensors are mapped onto this critic's device, so a checkpoint saved on
        GPU can be restored on a CPU-only machine (and vice versa).
        """
        device = self.tpdv["device"]
        critic_state_dict = torch.load(
            str(model_dir) + "/critic_agent.pt", map_location=device
        )
        self.critic.load_state_dict(critic_state_dict)
        target_critic_state_dict = torch.load(
            str(model_dir) + "/target_critic_agent.pt", map_location=device
        )
        self.target_critic.load_state_dict(target_critic_state_dict)

    def turn_on_grad(self):
        """Enable gradients for every parameter of the online critic."""
        for param in self.critic.parameters():
            param.requires_grad = True

    def turn_off_grad(self):
        """Disable gradients for every parameter of the online critic."""
        for param in self.critic.parameters():
            param.requires_grad = False