class ContinuousQCritic:
    """Continuous Q Critic.

    Critic that learns a Q-function. The action space is continuous.
    Note that the name ContinuousQCritic emphasizes its structure: it takes observations
    and actions as input and outputs the Q values. It is therefore commonly used for
    continuous action spaces, although the same structure could in principle handle
    discrete action spaces as well. For now, only continuous action spaces are supported;
    discrete action spaces may be added in the future.
    """

    def __init__(
        self,
        args,
        share_obs_space,
        act_space,
        num_agents,
        state_type,
        device=torch.device("cpu"),
    ):
        """Initialize the critic."""
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.act_space = act_space
        self.num_agents = num_agents
        self.state_type = state_type
        self.critic = ContinuousQNet(args, share_obs_space, act_space, device)
        self.target_critic = deepcopy(self.critic)
        # The target network is never updated by gradient descent, only by soft updates.
        for p in self.target_critic.parameters():
            p.requires_grad = False
        self.gamma = args["gamma"]
        self.critic_lr = args["critic_lr"]
        self.polyak = args["polyak"]
        self.use_proper_time_limits = args["use_proper_time_limits"]
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.critic_lr
        )
        self.turn_off_grad()
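A minimal construction sketch, not taken from the library: the args keys shown are exactly the ones read in __init__, while the Gym-style spaces, dimensions, agent count, and state_type string are illustrative placeholders (ContinuousQNet will likely expect additional keys in args, such as network sizes).

# Construction sketch; spaces, dims, and state_type are assumptions for illustration.
import torch
from gym import spaces

args = {
    "gamma": 0.99,
    "critic_lr": 5e-4,
    "polyak": 0.005,
    "use_proper_time_limits": True,
    # ... plus whatever keys ContinuousQNet expects (hidden sizes, activation, ...)
}
share_obs_space = spaces.Box(low=-1.0, high=1.0, shape=(24,))
act_space = [spaces.Box(low=-1.0, high=1.0, shape=(4,)) for _ in range(2)]

critic = ContinuousQCritic(
    args,
    share_obs_space,
    act_space,
    num_agents=2,
    state_type="EP",  # illustrative value
    device=torch.device("cpu"),
)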
    def lr_decay(self, step, steps):
        """Decay the critic learning rate linearly.

        Args:
            step: (int) current training step.
            steps: (int) total number of training steps.
        """
        update_linear_schedule(self.critic_optimizer, step, steps, self.critic_lr)
    def soft_update(self):
        """Soft update the target network."""
        for param_target, param in zip(
            self.target_critic.parameters(), self.critic.parameters()
        ):
            # Polyak averaging: move each target parameter a fraction `polyak`
            # toward the corresponding online parameter.
            param_target.data.copy_(
                param_target.data * (1.0 - self.polyak) + param.data * self.polyak
            )
    def get_values(self, share_obs, actions):
        """Get the Q values."""
        share_obs = check(share_obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)
        return self.critic(share_obs, actions)
    def train(
        self,
        share_obs,
        actions,
        reward,
        done,
        term,
        next_share_obs,
        next_actions,
        gamma,
    ):
        """Train the critic.

        Args:
            share_obs: (np.ndarray) shape is (batch_size, dim)
            actions: (np.ndarray) shape is (n_agents, batch_size, dim)
            reward: (np.ndarray) shape is (batch_size, 1)
            done: (np.ndarray) shape is (batch_size, 1)
            term: (np.ndarray) shape is (batch_size, 1)
            next_share_obs: (np.ndarray) shape is (batch_size, dim)
            next_actions: (np.ndarray) shape is (n_agents, batch_size, dim)
            gamma: (np.ndarray) shape is (batch_size, 1)
        """
        assert share_obs.__class__.__name__ == "ndarray"
        assert actions.__class__.__name__ == "ndarray"
        assert reward.__class__.__name__ == "ndarray"
        assert done.__class__.__name__ == "ndarray"
        assert term.__class__.__name__ == "ndarray"
        assert next_share_obs.__class__.__name__ == "ndarray"
        assert gamma.__class__.__name__ == "ndarray"

        share_obs = check(share_obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)
        # Concatenate per-agent actions into a single joint action along the last dimension.
        actions = torch.cat([actions[i] for i in range(actions.shape[0])], dim=-1)
        reward = check(reward).to(**self.tpdv)
        done = check(done).to(**self.tpdv)
        term = check(term).to(**self.tpdv)
        next_share_obs = check(next_share_obs).to(**self.tpdv)
        next_actions = torch.cat(next_actions, dim=-1).to(**self.tpdv)
        gamma = check(gamma).to(**self.tpdv)

        next_q_values = self.target_critic(next_share_obs, next_actions)
        if self.use_proper_time_limits:
            # Bootstrap through time-limit truncations: only true terminations cut off the return.
            q_targets = reward + gamma * next_q_values * (1 - term)
        else:
            q_targets = reward + gamma * next_q_values * (1 - done)
        critic_loss = torch.mean(
            torch.nn.functional.mse_loss(self.critic(share_obs, actions), q_targets)
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
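A sketch of one critic update, continuing the construction sketch above. The array shapes follow the docstring; the random data and dimensions are assumptions, and next_actions is passed as a list of per-agent torch tensors so that torch.cat inside train() works. Since __init__ ends with turn_off_grad(), gradients are enabled around the call.

# One illustrative critic update on dummy data (shapes follow the docstring).
import numpy as np
import torch

n_agents, batch_size, obs_dim, act_dim = 2, 32, 24, 4

share_obs = np.random.randn(batch_size, obs_dim).astype(np.float32)
actions = np.random.randn(n_agents, batch_size, act_dim).astype(np.float32)
reward = np.random.randn(batch_size, 1).astype(np.float32)
done = np.zeros((batch_size, 1), dtype=np.float32)
term = np.zeros((batch_size, 1), dtype=np.float32)
next_share_obs = np.random.randn(batch_size, obs_dim).astype(np.float32)
# Per-agent target actions as torch tensors (train() concatenates them with torch.cat).
next_actions = [torch.randn(batch_size, act_dim) for _ in range(n_agents)]
gamma = np.full((batch_size, 1), 0.99, dtype=np.float32)

critic.turn_on_grad()   # gradients are off after __init__
critic.train(share_obs, actions, reward, done, term, next_share_obs, next_actions, gamma)
critic.turn_off_grad()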
    def save(self, save_dir):
        """Save the model."""
        torch.save(
            self.critic.state_dict(), str(save_dir) + "/critic_agent" + ".pt"
        )
        torch.save(
            self.target_critic.state_dict(),
            str(save_dir) + "/target_critic_agent" + ".pt",
        )
    def restore(self, model_dir):
        """Restore the model."""
        critic_state_dict = torch.load(str(model_dir) + "/critic_agent" + ".pt")
        self.critic.load_state_dict(critic_state_dict)
        target_critic_state_dict = torch.load(
            str(model_dir) + "/target_critic_agent" + ".pt"
        )
        self.target_critic.load_state_dict(target_critic_state_dict)
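A small, illustrative checkpoint round-trip; the directory path is hypothetical and must exist before saving.

import os

ckpt_dir = "./results/critic_ckpt"   # hypothetical path
os.makedirs(ckpt_dir, exist_ok=True)
critic.save(ckpt_dir)      # writes critic_agent.pt and target_critic_agent.pt
critic.restore(ckpt_dir)   # loads both state dicts back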
    def turn_on_grad(self):
        """Turn on the gradient for the critic."""
        for param in self.critic.parameters():
            param.requires_grad = True
    def turn_off_grad(self):
        """Turn off the gradient for the critic."""
        for param in self.critic.parameters():
            param.requires_grad = False
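One plausible shape for the surrounding training loop, assuming the critic built above. The buffer sampling and target-action helpers are hypothetical placeholders; only the critic's own methods are real.

# Hypothetical driver loop; sample_from_buffer and get_target_actions stand in
# for whatever off-policy algorithm surrounds this critic.
total_steps = 100_000
for step in range(total_steps):
    (share_obs, actions, reward, done, term,
     next_share_obs, gamma) = sample_from_buffer()        # hypothetical helper
    next_actions = get_target_actions(next_share_obs)     # hypothetical helper (per-agent tensors)

    critic.turn_on_grad()                                  # gradients are off by default
    critic.train(share_obs, actions, reward, done, term,
                 next_share_obs, next_actions, gamma)
    critic.turn_off_grad()

    critic.soft_update()                 # nudge the target net toward the online net by `polyak`
    critic.lr_decay(step, total_steps)   # optional linear learning-rate schedule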