import torch
import torch.nn as nn

# Intra-package helpers; adjust these import paths to match the package layout.
from harl.utils.envs_tools import check, get_shape_from_obs_space
from harl.models.base.cnn import CNNBase
from harl.models.base.mlp import MLPBase
from harl.models.base.act import ACTLayer


class StochasticMlpPolicy(nn.Module):
    """Stochastic policy model that only uses an MLP network.

    Outputs actions given observations.
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        """Initialize StochasticMlpPolicy model.

        Args:
            args: (dict) arguments containing relevant model information.
            obs_space: (gym.Space) observation space.
            action_space: (gym.Space) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super(StochasticMlpPolicy, self).__init__()
        self.hidden_sizes = args["hidden_sizes"]
        self.args = args
        self.gain = args["gain"]
        self.initialization_method = args["initialization_method"]
        self.tpdv = dict(dtype=torch.float32, device=device)

        obs_shape = get_shape_from_obs_space(obs_space)
        # Pick the backbone from the observation shape: a CNN for image-like
        # (3D) observations, an MLP otherwise.
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)
        self.act = ACTLayer(
            action_space,
            self.hidden_sizes[-1],
            self.initialization_method,
            self.gain,
            args,
        )
        self.to(device)
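    # Usage sketch (illustrative, not part of the module): constructing the
    # policy. Only "hidden_sizes", "gain", and "initialization_method" are
    # read from ``args`` in this file; the backbone and ACTLayer read further
    # keys, so the dict below is a hypothetical minimum.
    #
    #     from gym import spaces
    #
    #     args = {
    #         "hidden_sizes": [128, 128],
    #         "gain": 0.01,
    #         "initialization_method": "orthogonal_",
    #         # ... plus any keys MLPBase / ACTLayer expect
    #     }
    #     policy = StochasticMlpPolicy(
    #         args,
    #         obs_space=spaces.Box(low=-1.0, high=1.0, shape=(24,)),
    #         action_space=spaces.Discrete(5),
    #     )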
    def forward(self, obs, available_actions=None, stochastic=True):
        """Compute actions from the given inputs.

        Args:
            obs: (np.ndarray / torch.Tensor) observation inputs into network.
            available_actions: (np.ndarray / torch.Tensor) denotes which
                actions are available to the agent (if None, all actions are
                available).
            stochastic: (bool) whether to sample from the action distribution
                or return the mode.

        Returns:
            actions: (torch.Tensor) actions to take.
        """
        obs = check(obs).to(**self.tpdv)
        deterministic = not stochastic
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        actor_features = self.base(obs)
        # The action log-probabilities returned by ACTLayer are not needed here.
        actions, _ = self.act(actor_features, available_actions, deterministic)
        return actions
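    # Usage sketch (illustrative, not part of the module): sampling actions
    # for a batch of observations. ``check`` accepts numpy input, and the
    # shapes assume the hypothetical spaces from the construction sketch above.
    #
    #     import numpy as np
    #
    #     obs = np.random.randn(8, 24).astype(np.float32)  # batch of 8
    #     actions = policy(obs)                    # sampled from the distribution
    #     greedy = policy(obs, stochastic=False)   # distribution mode instead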
    def get_logits(self, obs, available_actions=None):
        """Get action logits from the given inputs.

        Args:
            obs: (np.ndarray / torch.Tensor) input to network.
            available_actions: (np.ndarray / torch.Tensor) denotes which
                actions are available to the agent (if None, all actions are
                available).

        Returns:
            action_logits: (torch.Tensor) logits of actions for the given inputs.
        """
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        actor_features = self.base(obs)
        return self.act.get_logits(actor_features, available_actions)
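    # Usage sketch (illustrative, not part of the module): inspecting the
    # action distribution. ``get_logits`` defers to ACTLayer, so the exact
    # return shape depends on that layer; for a Discrete space a
    # (batch, n_actions) logit tensor is the natural output.
    #
    #     logits = policy.get_logits(obs)
    #     probs = torch.softmax(logits, dim=-1)  # per-action probabilities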