"""Base class for on-policy algorithms."""
import torch
from harl.models.policy_models.stochastic_policy import StochasticPolicy
from harl.utils.models_tools import update_linear_schedule
class OnPolicyBase:
    def __init__(self, args, obs_space, act_space, device=torch.device("cpu")):
        """Initialize Base class.
        Args:
            args: (dict) arguments.
            obs_space: (gym.spaces or list) observation space.
            act_space: (gym.spaces) action space.
            device: (torch.device) device to use for tensor operations.
        """
        # save arguments
        self.args = args
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
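        # dtype/device kwargs reused when casting inputs to float tensors (e.g. x.to(**self.tpdv))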
        self.data_chunk_length = args["data_chunk_length"]
        self.use_recurrent_policy = args["use_recurrent_policy"]
        self.use_naive_recurrent_policy = args["use_naive_recurrent_policy"]
        self.use_policy_active_masks = args["use_policy_active_masks"]
        self.action_aggregation = args["action_aggregation"]
        self.lr = args["lr"]
        self.opti_eps = args["opti_eps"]
        self.weight_decay = args["weight_decay"]
        # save observation and action spaces
        self.obs_space = obs_space
        self.act_space = act_space
        # create actor network
        self.actor = StochasticPolicy(args, self.obs_space, self.act_space, self.device)
        # create actor optimizer
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(),
            lr=self.lr,
            eps=self.opti_eps,
            weight_decay=self.weight_decay,
        )
    def lr_decay(self, episode, episodes):
        """Decay the learning rates.
        Args:
            episode: (int) current training episode.
            episodes: (int) total number of training episodes.
        """
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
    def get_actions(
        self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False
    ):
        """Compute actions for the given inputs.
        Args:
            obs: (np.ndarray) local agent inputs to the actor.
            rnn_states_actor: (np.ndarray) if actor has RNN layer, RNN states for actor.
            masks: (np.ndarray) denotes points at which RNN states should be reset.
            available_actions: (np.ndarray) denotes which actions are available to the agent
                (if None, all actions are available)
            deterministic: (bool) whether to take the mode of the action distribution or to sample from it.
        """
        actions, action_log_probs, rnn_states_actor = self.actor(
            obs, rnn_states_actor, masks, available_actions, deterministic
        )
        return actions, action_log_probs, rnn_states_actor
    def evaluate_actions(
        self,
        obs,
        rnn_states_actor,
        action,
        masks,
        available_actions=None,
        active_masks=None,
    ):
        """Get action logprobs, entropy, and distributions for actor update.
        Args:
            obs: (np.ndarray / torch.Tensor) local agent inputs to the actor.
            rnn_states_actor: (np.ndarray / torch.Tensor) if actor has RNN layer, RNN states for actor.
            action: (np.ndarray / torch.Tensor) actions whose log probabilities and entropy to compute.
            masks: (np.ndarray / torch.Tensor) denotes points at which RNN states should be reset.
            available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to the agent
                (if None, all actions are available)
            active_masks: (np.ndarray / torch.Tensor) denotes whether an agent is active or dead.
        """
        (
            action_log_probs,
            dist_entropy,
            action_distribution,
        ) = self.actor.evaluate_actions(
            obs, rnn_states_actor, action, masks, available_actions, active_masks
        )
        return action_log_probs, dist_entropy, action_distribution
    def act(
        self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False
    ):
        """Compute actions using the given inputs.
        Args:
            obs: (np.ndarray) local agent inputs to the actor.
            rnn_states_actor: (np.ndarray) if actor has RNN layer, RNN states for actor.
            masks: (np.ndarray) denotes points at which RNN states should be reset.
            available_actions: (np.ndarray) denotes which actions are available to the agent
                (if None, all actions are available)
            deterministic: (bool) whether to take the mode of the action distribution or to sample from it.
        """
        actions, _, rnn_states_actor = self.actor(
            obs, rnn_states_actor, masks, available_actions, deterministic
        )
        return actions, rnn_states_actor
    def update(self, sample):
        """Update actor network.
        Args:
            sample: (Tuple) contains data batch with which to update networks.
        """
        pass
    def train(self, actor_buffer, advantages, state_type):
        """Perform a training update using minibatch gradient descent.
        Args:
            actor_buffer: (OnPolicyActorBuffer) buffer containing training data related to actor.
            advantages: (np.ndarray) advantages.
            state_type: (str) type of state.
        """
        pass
    def prep_training(self):
        """Prepare for training."""
        self.actor.train()
    def prep_rollout(self):
        """Prepare for rollout."""
        self.actor.eval()
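

# Usage sketch (illustrative only, not part of the library): a concrete on-policy
# algorithm subclasses OnPolicyBase and fills in update()/train(), e.g.
#
#   class MyOnPolicyAlgo(OnPolicyBase):
#       def update(self, sample):
#           ...  # compute the actor loss from the sample and step self.actor_optimizer
#
#       def train(self, actor_buffer, advantages, state_type):
#           ...  # build minibatches from actor_buffer and call self.update on each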