Source code for harl.algorithms.actors.hasac

"""HASAC algorithm."""
import torch
from harl.models.policy_models.squashed_gaussian_policy import SquashedGaussianPolicy
from harl.models.policy_models.stochastic_mlp_policy import StochasticMlpPolicy
from harl.utils.discrete_util import gumbel_softmax
from harl.utils.envs_tools import check
from harl.algorithms.actors.off_policy_base import OffPolicyBase


class HASAC(OffPolicyBase):
    def __init__(self, args, obs_space, act_space, device=torch.device("cpu")):
        """Initialize HASAC actor.

        Args:
            args: (dict) arguments containing relevant model information.
            obs_space: (gym.spaces) observation space.
            act_space: (gym.spaces) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.polyak = args["polyak"]
        self.lr = args["lr"]
        self.device = device
        self.action_type = act_space.__class__.__name__
        if act_space.__class__.__name__ == "Box":
            self.actor = SquashedGaussianPolicy(args, obs_space, act_space, device)
        else:
            self.actor = StochasticMlpPolicy(args, obs_space, act_space, device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr)
        self.turn_off_grad()
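    # Example construction (sketch): the exact keys of ``args`` come from the
    # HARL config files; only "polyak" and "lr" are read here, while the policy
    # networks consume further keys (e.g. hidden sizes), so the dict below is a
    # hypothetical minimum rather than a complete config.
    #
    #     from gym import spaces
    #     obs_space = spaces.Box(low=-1.0, high=1.0, shape=(8,))
    #     act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))
    #     args = {"polyak": 0.995, "lr": 3e-4}  # plus the model keys required by the policies
    #     actor = HASAC(args, obs_space, act_space, device=torch.device("cpu"))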
    def get_actions(self, obs, available_actions=None, stochastic=True):
        """Get actions for observations.

        Args:
            obs: (np.ndarray) observations of actor, shape is (n_threads, dim) or (batch_size, dim)
            available_actions: (np.ndarray) denotes which actions are available to agent
                (if None, all actions available)
            stochastic: (bool) stochastic actions or deterministic actions

        Returns:
            actions: (torch.Tensor) actions taken by this actor, shape is (n_threads, dim) or (batch_size, dim)
        """
        obs = check(obs).to(**self.tpdv)
        if self.action_type == "Box":
            actions, _ = self.actor(obs, stochastic=stochastic, with_logprob=False)
        else:
            actions = self.actor(obs, available_actions, stochastic)
        return actions
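    # Example rollout usage (sketch): observations arrive as a numpy batch and
    # actions come back as a torch.Tensor on ``self.device``; ``n_threads`` and
    # ``obs_dim`` below are assumed placeholders for illustration.
    #
    #     import numpy as np
    #     obs = np.zeros((n_threads, obs_dim), dtype=np.float32)
    #     actions = actor.get_actions(obs, stochastic=True)  # (n_threads, act_dim)
    #     env_actions = actions.detach().cpu().numpy()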
    def get_actions_with_logprobs(self, obs, available_actions=None, stochastic=True):
        """Get actions and logprobs of actions for observations.

        Args:
            obs: (np.ndarray) observations of actor, shape is (batch_size, dim)
            available_actions: (np.ndarray) denotes which actions are available to agent
                (if None, all actions available)
            stochastic: (bool) stochastic actions or deterministic actions

        Returns:
            actions: (torch.Tensor) actions taken by this actor, shape is (batch_size, dim)
            logp_actions: (torch.Tensor) log probabilities of actions taken by this actor,
                shape is (batch_size, 1) for Box and Discrete action spaces, and
                (batch_size, num_sub_actions) for MultiDiscrete action spaces.
        """
        obs = check(obs).to(**self.tpdv)
        if self.action_type == "Box":
            actions, logp_actions = self.actor(
                obs, stochastic=stochastic, with_logprob=True
            )
        elif self.action_type == "Discrete":
            logits = self.actor.get_logits(obs, available_actions)
            actions = gumbel_softmax(
                logits, hard=True, device=self.device
            )  # onehot actions
            logp_actions = torch.sum(actions * logits, dim=-1, keepdim=True)
        elif self.action_type == "MultiDiscrete":
            logits = self.actor.get_logits(obs, available_actions)
            actions = []
            logp_actions = []
            for logit in logits:
                action = gumbel_softmax(
                    logit, hard=True, device=self.device
                )  # onehot actions
                logp_action = torch.sum(action * logit, dim=-1, keepdim=True)
                actions.append(action)
                logp_actions.append(logp_action)
            actions = torch.cat(actions, dim=-1)
            logp_actions = torch.cat(logp_actions, dim=-1)
        return actions, logp_actions
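    # Example update-time usage (sketch): the sampled actions stay differentiable
    # (reparameterization for Box, Gumbel-Softmax straight-through for discrete
    # spaces), so gradients can flow from a SAC-style maximum-entropy objective
    # back into the actor parameters; ``buffer_obs`` is a hypothetical
    # (batch_size, obs_dim) batch drawn from a replay buffer.
    #
    #     actions, logp = actor.get_actions_with_logprobs(buffer_obs)
    #     # ``logp`` enters the entropy term of the policy loss during updates.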
    def save(self, save_dir, id):
        """Save the actor."""
        torch.save(
            self.actor.state_dict(), str(save_dir) + "/actor_agent" + str(id) + ".pt"
        )
    def restore(self, model_dir, id):
        """Restore the actor."""
        actor_state_dict = torch.load(str(model_dir) + "/actor_agent" + str(id) + ".pt")
        self.actor.load_state_dict(actor_state_dict)
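# Example checkpointing sketch: paths follow the "<dir>/actor_agent<id>.pt"
# pattern used above; the checkpoint directory is assumed to exist already.
#
#     actor.save("./checkpoints", 0)     # writes ./checkpoints/actor_agent0.pt
#     actor.restore("./checkpoints", 0)  # loads the weights back into self.actor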