# Source code for harl.models.policy_models.squashed_gaussian_policy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from harl.utils.envs_tools import get_shape_from_obs_space
from harl.models.base.plain_cnn import PlainCNN
from harl.models.base.plain_mlp import PlainMLP


# Bounds applied to the predicted log standard deviation before exponentiation,
# so that std = exp(log_std) stays within [exp(-20), exp(2)].
# NOTE(review): these names were referenced but not defined anywhere in the
# visible file; the values below are the ones used by the reference SAC
# implementation (OpenAI Spinning Up) -- confirm against upstream.
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class SquashedGaussianPolicy(nn.Module):
    """Squashed Gaussian policy network for HASAC."""

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        """Initialize SquashedGaussianPolicy model.

        Args:
            args: (dict) arguments containing relevant model information.
                Reads the keys "hidden_sizes", "activation_func" and
                "final_activation_func".
            obs_space: (gym.Space) observation space.
            action_space: (gym.Space) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super().__init__()
        self.tpdv = dict(dtype=torch.float32, device=device)
        hidden_sizes = args["hidden_sizes"]
        activation_func = args["activation_func"]
        final_activation_func = args["final_activation_func"]

        obs_shape = get_shape_from_obs_space(obs_space)
        if len(obs_shape) == 3:
            # Rank-3 observation shape: run it through a CNN whose output
            # width is the first hidden size.
            self.feature_extractor = PlainCNN(
                obs_shape, hidden_sizes[0], activation_func
            )
            feature_dim = hidden_sizes[0]
        else:
            # Otherwise the raw observation is fed to the MLP directly.
            self.feature_extractor = None
            feature_dim = obs_shape[0]

        act_dim = action_space.shape[0]
        # Shared trunk; its output feeds both the mean and the log-std heads.
        self.net = PlainMLP(
            [feature_dim] + list(hidden_sizes), activation_func, final_activation_func
        )
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        # action limit for clamping (assumes all dimensions share the same bound)
        self.act_limit = action_space.high[0]
        # Move all parameters to the requested device, as the `device`
        # argument promises.
        self.to(device)

    def forward(self, obs, stochastic=True, with_logprob=True):
        """Return an action scaled to the action-space limit and its log-prob.

        Args:
            obs: observation input. No conversion is performed here, so this
                assumes `obs` is already a tensor the network can consume
                (feature/action dimension last) -- TODO confirm with callers.
            stochastic: if True, draw a reparameterized sample from the
                Gaussian; if False, use the mean (only used for evaluating
                the policy at test time).
            with_logprob: if True, also compute the log-probability of the
                squashed action; otherwise None is returned in its place.

        Returns:
            pi_action: `act_limit * tanh(u)` where `u` is the pre-squash
                action (sample or mean).
            logp_pi: log-probability summed over the last dimension (kept as
                a size-1 dimension), or None when `with_logprob` is False.
        """
        if self.feature_extractor is not None:
            x = self.feature_extractor(obs)
        else:
            x = obs
        net_out = self.net(x)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if stochastic:
            # rsample() keeps the sample differentiable w.r.t. mu and std.
            pi_action = pi_distribution.rsample()
        else:
            # Only used for evaluating policy at test time.
            pi_action = mu

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh
            # squashing. The correction is a numerically stable equivalent of
            # sum(log(1 - tanh(u)^2)); see the original SAC paper
            # (arXiv 1801.01290), appendix C, Eq 21.
            # Both reductions run over the last (action) dimension so that
            # unbatched and higher-rank inputs are handled consistently.
            gaussian_logp = pi_distribution.log_prob(pi_action).sum(
                dim=-1, keepdim=True
            )
            tanh_correction = (
                2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))
            ).sum(dim=-1, keepdim=True)
            logp_pi = gaussian_logp - tanh_correction
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi