Source code for harl.algorithms.actors.off_policy_base

"""Base class for off-policy algorithms."""

from copy import deepcopy
import numpy as np
import torch
from harl.utils.envs_tools import check
from harl.utils.models_tools import update_linear_schedule


class OffPolicyBase:
    """Base class for off-policy actor algorithms."""

    def __init__(self, args, obs_space, act_space, device=torch.device("cpu")):
        # Concrete algorithms set up the actor, target actor, and optimizer here.
        pass

    def lr_decay(self, step, steps):
        """Decay the actor learning rate linearly.

        Args:
            step: (int) current training step.
            steps: (int) total number of training steps.
        """
        update_linear_schedule(self.actor_optimizer, step, steps, self.lr)

    def get_actions(self, obs, randomness):
        """Get actions for the given observations; implemented by subclasses."""
        pass

    def get_target_actions(self, obs):
        """Get target-network actions for the given observations; implemented by subclasses."""
        pass

    def soft_update(self):
        """Soft update target actor."""
        for param_target, param in zip(
            self.target_actor.parameters(), self.actor.parameters()
        ):
            param_target.data.copy_(
                param_target.data * (1.0 - self.polyak) + param.data * self.polyak
            )

    def save(self, save_dir, id):
        """Save the actor and target actor."""
        torch.save(
            self.actor.state_dict(), str(save_dir) + "/actor_agent" + str(id) + ".pt"
        )
        torch.save(
            self.target_actor.state_dict(),
            str(save_dir) + "/target_actor_agent" + str(id) + ".pt",
        )

    def restore(self, model_dir, id):
        """Restore the actor and target actor."""
        actor_state_dict = torch.load(str(model_dir) + "/actor_agent" + str(id) + ".pt")
        self.actor.load_state_dict(actor_state_dict)
        target_actor_state_dict = torch.load(
            str(model_dir) + "/target_actor_agent" + str(id) + ".pt"
        )
        self.target_actor.load_state_dict(target_actor_state_dict)

    def turn_on_grad(self):
        """Turn on grad for actor parameters."""
        for p in self.actor.parameters():
            p.requires_grad = True

    def turn_off_grad(self):
        """Turn off grad for actor parameters."""
        for p in self.actor.parameters():
            p.requires_grad = False
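
A minimal sketch of how a concrete algorithm might subclass OffPolicyBase is shown below. It assumes flat Box observation and action spaces and an args dict carrying "lr", "polyak", and "expl_noise"; the network SimpleDeterministicPolicy and the class SketchActor are hypothetical and not part of HARL.

# Illustrative only: everything below is a hypothetical sketch, not HARL code.
import torch.nn as nn


class SimpleDeterministicPolicy(nn.Module):
    """Tiny MLP mapping observations to actions in [-1, 1] (hypothetical)."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim), nn.Tanh()
        )

    def forward(self, obs):
        return self.net(obs)


class SketchActor(OffPolicyBase):
    """Hypothetical DDPG-style actor built on OffPolicyBase."""

    def __init__(self, args, obs_space, act_space, device=torch.device("cpu")):
        self.lr = args["lr"]                  # actor learning rate (assumed key)
        self.polyak = args["polyak"]          # soft-update coefficient (assumed key)
        self.expl_noise = args["expl_noise"]  # exploration noise std (assumed key)
        self.device = device
        # assumes flat Box observation and action spaces
        self.actor = SimpleDeterministicPolicy(
            obs_space.shape[0], act_space.shape[0]
        ).to(device)
        self.target_actor = deepcopy(self.actor)
        # the target network is updated only via soft_update, never by gradients
        for p in self.target_actor.parameters():
            p.requires_grad = False
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr)

    def get_actions(self, obs, randomness):
        """Deterministic action, plus Gaussian exploration noise if requested."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        actions = self.actor(obs)
        if randomness:
            actions = actions + self.expl_noise * torch.randn_like(actions)
        return actions.clamp(-1.0, 1.0)

    def get_target_actions(self, obs):
        """Target-network actions used when bootstrapping the critic."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        return self.target_actor(obs)

With a subclass shaped like this, the inherited soft_update, save, restore, turn_on_grad, and turn_off_grad methods work unchanged, since they rely only on self.actor, self.target_actor, self.actor_optimizer, self.polyak, and self.lr being set in __init__.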