Source code for harl.algorithms.actors.off_policy_base
"""Base class for off-policy algorithms."""fromcopyimportdeepcopyimportnumpyasnpimporttorchfromharl.utils.envs_toolsimportcheckfromharl.utils.models_toolsimportupdate_linear_schedule
def lr_decay(self, step, steps):
    """Linearly decay the actor learning rate as training progresses.

    Args:
        step: (int) current training step.
        steps: (int) total number of training steps.
    """
    # The shared helper rewrites self.actor_optimizer's lr in place,
    # scaling self.lr by the remaining fraction of training.
    update_linear_schedule(self.actor_optimizer, step, steps, self.lr)
def save(self, save_dir, id):
    """Serialize the actor and target actor weights to ``save_dir``.

    Args:
        save_dir: directory in which the ``.pt`` files are written.
        id: agent identifier embedded in each file name.
    """
    base = str(save_dir)
    # File names must match what restore() expects to read back.
    torch.save(self.actor.state_dict(), f"{base}/actor_agent{id}.pt")
    torch.save(self.target_actor.state_dict(), f"{base}/target_actor_agent{id}.pt")
def restore(self, model_dir, id):
    """Load actor and target actor weights previously written by ``save``.

    Args:
        model_dir: directory containing the ``.pt`` checkpoint files.
        id: agent identifier embedded in each file name.
    """
    base = str(model_dir)
    # NOTE(review): torch.load has no map_location here, so checkpoints
    # presumably load on the same device they were saved from — confirm.
    self.actor.load_state_dict(torch.load(f"{base}/actor_agent{id}.pt"))
    self.target_actor.load_state_dict(torch.load(f"{base}/target_actor_agent{id}.pt"))
def turn_on_grad(self):
    """Enable gradient computation for every actor parameter."""
    for param in self.actor.parameters():
        param.requires_grad = True
def turn_off_grad(self):
    """Disable gradient computation for every actor parameter."""
    for param in self.actor.parameters():
        param.requires_grad = False