# Imports and log-std bounds restored from the surrounding module (paths as in
# the HARL repository; the bounds match SpinningUp's SAC implementation).
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from harl.models.base.plain_cnn import PlainCNN
from harl.models.base.plain_mlp import PlainMLP
from harl.utils.envs_tools import get_shape_from_obs_space

LOG_STD_MAX = 2
LOG_STD_MIN = -20


class SquashedGaussianPolicy(nn.Module):
    """Squashed Gaussian policy network for HASAC."""

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        """Initialize SquashedGaussianPolicy model.

        Args:
            args: (dict) arguments containing relevant model information.
            obs_space: (gym.Space) observation space.
            action_space: (gym.Space) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super().__init__()
        self.tpdv = dict(dtype=torch.float32, device=device)
        hidden_sizes = args["hidden_sizes"]
        activation_func = args["activation_func"]
        final_activation_func = args["final_activation_func"]
        obs_shape = get_shape_from_obs_space(obs_space)
        if len(obs_shape) == 3:
            # Image observations: extract a flat feature vector with a CNN first.
            self.feature_extractor = PlainCNN(obs_shape, hidden_sizes[0], activation_func)
            feature_dim = hidden_sizes[0]
        else:
            self.feature_extractor = None
            feature_dim = obs_shape[0]
        act_dim = action_space.shape[0]
        self.net = PlainMLP(
            [feature_dim] + list(hidden_sizes), activation_func, final_activation_func
        )
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        # Action limit for scaling tanh outputs (assumes all dimensions share the same bound).
        self.act_limit = action_space.high[0]
        self.to(device)
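    # Usage sketch (hypothetical config values, not HARL's shipped defaults):
    # construct the policy for flat Box spaces and sample one action.
    #
    #     args = {"hidden_sizes": [256, 256],
    #             "activation_func": "relu",
    #             "final_activation_func": "relu"}
    #     obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,))
    #     action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
    #     policy = SquashedGaussianPolicy(args, obs_space, action_space)
    #     action, logp = policy(torch.zeros(1, 8))  # shapes: (1, 2), (1, 1)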
    def forward(self, obs, stochastic=True, with_logprob=True):
        """Return actions from the network, scaled to the action-space limits."""
        if self.feature_extractor is not None:
            x = self.feature_extractor(obs)
        else:
            x = obs
        net_out = self.net(x)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample.
        pi_distribution = Normal(mu, std)
        if not stochastic:
            # Deterministic mean action; only used for evaluating the policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute the Gaussian log-probability, then apply the correction for the
            # tanh squashing. NOTE: The correction formula is a little bit magic. To
            # understand where it comes from, see appendix C of the original SAC paper
            # (arXiv 1801.01290); this is a more numerically stable equivalent to Eq. 21.
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1, keepdim=True)
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(
                dim=-1, keepdim=True
            )
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action
        return pi_action, logp_pi
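# A minimal verification sketch (helper added here for illustration, not part of
# the HARL source): it checks that the stable correction used in forward(),
# 2 * (log 2 - u - softplus(-2u)), matches the per-dimension squashing term
# log(1 - tanh(u)^2) from appendix C of the SAC paper. The naive form loses
# precision for large |u| because 1 - tanh(u)^2 underflows toward zero.
if __name__ == "__main__":
    u = torch.linspace(-8.0, 8.0, steps=17, dtype=torch.float64)
    naive = torch.log(1.0 - torch.tanh(u) ** 2)        # cancels badly as |u| grows
    stable = 2 * (np.log(2) - u - F.softplus(-2 * u))  # form used in forward()
    assert torch.allclose(naive, stable, atol=1e-6)
    print("tanh log-prob correction verified")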