Source code for harl.models.base.act

import torch
import torch.nn as nn
from harl.models.base.distributions import Categorical, DiagGaussian


class ACTLayer(nn.Module):
    """MLP Module to compute actions."""

    def __init__(
        self, action_space, inputs_dim, initialization_method, gain, args=None
    ):
        """Initialize ACTLayer.
        Args:
            action_space: (gym.Space) action space.
            inputs_dim: (int) dimension of network input.
            initialization_method: (str) initialization method.
            gain: (float) gain of the output layer of the network.
            args: (dict) arguments relevant to the network.
        """
        super(ACTLayer, self).__init__()
        self.action_type = action_space.__class__.__name__
        self.multidiscrete_action = False

        if action_space.__class__.__name__ == "Discrete":
            action_dim = action_space.n
            self.action_out = Categorical(
                inputs_dim, action_dim, initialization_method, gain
            )
        elif action_space.__class__.__name__ == "Box":
            action_dim = action_space.shape[0]
            self.action_out = DiagGaussian(
                inputs_dim, action_dim, initialization_method, gain, args
            )
        elif action_space.__class__.__name__ == "MultiDiscrete":
            self.multidiscrete_action = True
            action_dims = action_space.nvec
            action_outs = []
            for action_dim in action_dims:
                action_outs.append(
                    Categorical(inputs_dim, action_dim, initialization_method, gain)
                )
            self.action_outs = nn.ModuleList(action_outs)
    def forward(self, x, available_actions=None, deterministic=False):
        """Compute actions and action logprobs from given input.
        Args:
            x: (torch.Tensor) input to network.
            available_actions: (torch.Tensor) denotes which actions are available to agent
                (if None, all actions are available)
            deterministic: (bool) whether to sample from the action distribution or return the mode.
        Returns:
            actions: (torch.Tensor) actions to take.
            action_log_probs: (torch.Tensor) log probabilities of taken actions.
        """
        if self.multidiscrete_action:
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_distribution = action_out(x, available_actions)
                action = (
                    action_distribution.mode()
                    if deterministic
                    else action_distribution.sample()
                )
                action_log_prob = action_distribution.log_probs(action)
                actions.append(action)
                action_log_probs.append(action_log_prob)
            actions = torch.cat(actions, dim=-1)
            action_log_probs = torch.cat(action_log_probs, dim=-1).sum(
                dim=-1, keepdim=True
            )
        else:
            action_distribution = self.action_out(x, available_actions)
            actions = (
                action_distribution.mode()
                if deterministic
                else action_distribution.sample()
            )
            action_log_probs = action_distribution.log_probs(actions)
        return actions, action_log_probs
    def get_logits(self, x, available_actions=None):
        """Get action logits from inputs.
        Args:
            x: (torch.Tensor) input to network.
            available_actions: (torch.Tensor) denotes which actions are available to agent
                (if None, all actions are available)
        Returns:
            action_logits: (torch.Tensor) logits of actions for the given inputs.
        """
        if self.multidiscrete_action:
            action_logits = []
            for action_out in self.action_outs:
                action_distribution = action_out(x, available_actions)
                action_logits.append(action_distribution.logits)
        else:
            action_distribution = self.action_out(x, available_actions)
            action_logits = action_distribution.logits
        return action_logits
    def evaluate_actions(self, x, action, available_actions=None, active_masks=None):
        """Compute action log probability, distribution entropy, and action distribution.
        Args:
            x: (torch.Tensor) input to network.
            action: (torch.Tensor) actions whose entropy and log probability to evaluate.
            available_actions: (torch.Tensor) denotes which actions are available to agent
                (if None, all actions are available)
            active_masks: (torch.Tensor) denotes whether an agent is active or dead.
        Returns:
            action_log_probs: (torch.Tensor) log probabilities of the input actions.
            dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
            action_distribution: (torch.distributions) action distribution.
        """
        if self.multidiscrete_action:
            action = torch.transpose(action, 0, 1)
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_distribution = action_out(x)
                action_log_probs.append(
                    action_distribution.log_probs(act.unsqueeze(-1))
                )
                if active_masks is not None:
                    dist_entropy.append(
                        (action_distribution.entropy() * active_masks)
                        / active_masks.sum()
                    )
                else:
                    dist_entropy.append(
                        action_distribution.entropy() / action_log_probs[-1].size(0)
                    )
            action_log_probs = torch.cat(action_log_probs, dim=-1).sum(
                dim=-1, keepdim=True
            )
            dist_entropy = (
                torch.cat(dist_entropy, dim=-1).sum(dim=-1, keepdim=True).mean()
            )
            return action_log_probs, dist_entropy, None
        else:
            action_distribution = self.action_out(x, available_actions)
            action_log_probs = action_distribution.log_probs(action)
            if active_masks is not None:
                if self.action_type == "Discrete":
                    dist_entropy = (
                        action_distribution.entropy() * active_masks.squeeze(-1)
                    ).sum() / active_masks.sum()
                else:
                    dist_entropy = (
                        action_distribution.entropy() * active_masks
                    ).sum() / active_masks.sum()
            else:
                dist_entropy = action_distribution.entropy().mean()
            return action_log_probs, dist_entropy, action_distribution
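
A minimal usage sketch (not part of the module): it assumes a gym Discrete(5) action space, a hidden size of 64, batch size 8, and the "orthogonal_" initialization with gain 0.01; these values are illustrative choices, not defaults of the library.

# Hypothetical example of driving ACTLayer; values below are assumptions for illustration.
import torch
from gym.spaces import Discrete
from harl.models.base.act import ACTLayer

action_space = Discrete(5)
act_layer = ACTLayer(
    action_space, inputs_dim=64, initialization_method="orthogonal_", gain=0.01
)

x = torch.randn(8, 64)  # batch of 8 actor feature vectors

# Sample actions and their log-probabilities from the categorical head.
actions, action_log_probs = act_layer(x, deterministic=False)

# Re-evaluate the same actions later, e.g. when computing a policy-gradient loss.
log_probs, dist_entropy, dist = act_layer.evaluate_actions(x, actions)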