"""On-policy buffer for actor."""
import torch
import numpy as np
from harl.utils.trans_tools import _flatten, _sa_cast
from harl.utils.envs_tools import get_shape_from_obs_space, get_shape_from_act_space
class OnPolicyActorBuffer:
"""On-policy buffer for actor data storage."""
def __init__(self, args, obs_space, act_space):
"""Initialize on-policy actor buffer.
Args:
args: (dict) arguments
obs_space: (gym.Space or list) observation space
act_space: (gym.Space) action space
"""
self.episode_length = args["episode_length"]
self.n_rollout_threads = args["n_rollout_threads"]
self.hidden_sizes = args["hidden_sizes"]
self.rnn_hidden_size = self.hidden_sizes[-1]
self.recurrent_n = args["recurrent_n"]
obs_shape = get_shape_from_obs_space(obs_space)
if isinstance(obs_shape[-1], list):
obs_shape = obs_shape[:1]
# Buffer for observations of this actor.
self.obs = np.zeros(
(self.episode_length + 1, self.n_rollout_threads, *obs_shape),
dtype=np.float32,
)
# Buffer for rnn states of this actor.
self.rnn_states = np.zeros(
(
self.episode_length + 1,
self.n_rollout_threads,
self.recurrent_n,
self.rnn_hidden_size,
),
dtype=np.float32,
)
# Buffer for available actions of this actor.
if act_space.__class__.__name__ == "Discrete":
self.available_actions = np.ones(
(self.episode_length + 1, self.n_rollout_threads, act_space.n),
dtype=np.float32,
)
else:
self.available_actions = None
act_shape = get_shape_from_act_space(act_space)
# Buffer for actions of this actor.
self.actions = np.zeros(
(self.episode_length, self.n_rollout_threads, act_shape), dtype=np.float32
)
# Buffer for action log probs of this actor.
self.action_log_probs = np.zeros(
(self.episode_length, self.n_rollout_threads, act_shape), dtype=np.float32
)
        # Buffer for masks of this actor. Masks denote at which steps the rnn states should be reset.
self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        # Buffer for active masks of this actor. Active masks denote whether the agent is alive at each step.
self.active_masks = np.ones_like(self.masks)
self.factor = None
self.step = 0
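        # A construction sketch (hypothetical values; in HARL the args dict comes
        # from the algorithm config and the spaces from the environment):
        #
        #     args = {"episode_length": 200, "n_rollout_threads": 8,
        #             "hidden_sizes": [128, 128], "recurrent_n": 1}
        #     buffer = OnPolicyActorBuffer(args, env.observation_space[0], env.action_space[0])
        #
        # With these values, self.obs has shape (201, 8, *obs_shape) and
        # self.actions has shape (200, 8, act_shape).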
def update_factor(self, factor):
"""Save factor for this actor."""
self.factor = factor.copy()
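        # Shape note: the generators below index the factor as
        # (episode_length, n_rollout_threads, factor_dim), e.g. via
        # self.factor[:, ids] and _sa_cast(self.factor), so the passed-in
        # factor is expected to match the leading buffer dimensions.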
def insert(
self,
obs,
rnn_states,
actions,
action_log_probs,
masks,
active_masks=None,
available_actions=None,
):
"""Insert data into actor buffer."""
self.obs[self.step + 1] = obs.copy()
self.rnn_states[self.step + 1] = rnn_states.copy()
self.actions[self.step] = actions.copy()
self.action_log_probs[self.step] = action_log_probs.copy()
self.masks[self.step + 1] = masks.copy()
if active_masks is not None:
self.active_masks[self.step + 1] = active_masks.copy()
if available_actions is not None:
self.available_actions[self.step + 1] = available_actions.copy()
self.step = (self.step + 1) % self.episode_length
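        # A rollout-loop sketch (hypothetical actor/env names; note that obs,
        # rnn_states, and masks describe the state after the step, while
        # actions and action_log_probs belong to the current step):
        #
        #     for step in range(episode_length):
        #         actions, action_log_probs, rnn_states = actor.get_actions(...)
        #         obs, rewards, dones, infos = envs.step(actions)
        #         masks = 1.0 - dones.astype(np.float32)[:, None]
        #         buffer.insert(obs, rnn_states, actions, action_log_probs, masks)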
def after_update(self):
"""After an update, copy the data at the last step to the first position of the buffer."""
self.obs[0] = self.obs[-1].copy()
self.rnn_states[0] = self.rnn_states[-1].copy()
self.masks[0] = self.masks[-1].copy()
self.active_masks[0] = self.active_masks[-1].copy()
if self.available_actions is not None:
self.available_actions[0] = self.available_actions[-1].copy()
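        # This makes the buffer reusable across training iterations: the data at
        # the final step seeds step 0 of the next rollout, so collection resumes
        # without resetting the environments.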
def feed_forward_generator_actor(
self, advantages, actor_num_mini_batch=None, mini_batch_size=None
):
"""Training data generator for actor that uses MLP network."""
# get episode_length, n_rollout_threads, mini_batch_size
episode_length, n_rollout_threads = self.actions.shape[0:2]
batch_size = n_rollout_threads * episode_length
if mini_batch_size is None:
            assert batch_size >= actor_num_mini_batch, (
                f"The number of processes ({n_rollout_threads}) "
                f"* the number of steps ({episode_length}) = {batch_size} "
                f"must be greater than or equal to the number of actor mini batches "
                f"({actor_num_mini_batch})."
            )
mini_batch_size = batch_size // actor_num_mini_batch
# shuffle indices
rand = torch.randperm(batch_size).numpy()
sampler = [
rand[i * mini_batch_size : (i + 1) * mini_batch_size]
for i in range(actor_num_mini_batch)
]
# Combine the first two dimensions (episode_length and n_rollout_threads) to form batch.
# Take obs shape as an example:
# (episode_length + 1, n_rollout_threads, *obs_shape) --> (episode_length, n_rollout_threads, *obs_shape)
# --> (episode_length * n_rollout_threads, *obs_shape)
obs = self.obs[:-1].reshape(-1, *self.obs.shape[2:])
        rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:])  # not used by the MLP actor, kept for interface consistency
actions = self.actions.reshape(-1, self.actions.shape[-1])
if self.available_actions is not None:
available_actions = self.available_actions[:-1].reshape(
-1, self.available_actions.shape[-1]
)
masks = self.masks[:-1].reshape(-1, 1)
active_masks = self.active_masks[:-1].reshape(-1, 1)
action_log_probs = self.action_log_probs.reshape(-1, self.action_log_probs.shape[-1])
if self.factor is not None:
factor = self.factor.reshape(-1, self.factor.shape[-1])
        if advantages is not None:
            advantages = advantages.reshape(-1, 1)
for indices in sampler:
# obs shape:
# (episode_length * n_rollout_threads, *obs_shape) --> (mini_batch_size, *obs_shape)
obs_batch = obs[indices]
rnn_states_batch = rnn_states[indices]
actions_batch = actions[indices]
if self.available_actions is not None:
available_actions_batch = available_actions[indices]
else:
available_actions_batch = None
masks_batch = masks[indices]
active_masks_batch = active_masks[indices]
old_action_log_probs_batch = action_log_probs[indices]
if advantages is None:
adv_targ = None
else:
adv_targ = advantages[indices]
if self.factor is None:
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch
else:
factor_batch = factor[indices]
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch
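        # A consumption sketch (hypothetical trainer code; advantages is assumed
        # to be an ndarray with leading dims (episode_length, n_rollout_threads)):
        #
        #     for batch in buffer.feed_forward_generator_actor(advantages, actor_num_mini_batch=4):
        #         obs_b, rnn_b, act_b, masks_b, active_b, old_logp_b, adv_b, avail_b = batch
        #         # each *_b has mini_batch_size as its leading dim; update the actor with it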
def naive_recurrent_generator_actor(self, advantages, actor_num_mini_batch):
"""Training data generator for actor that uses RNN network.
This generator does not split the trajectories into chunks,
        and therefore may be less efficient than recurrent_generator_actor in training.
"""
# get n_rollout_threads and num_envs_per_batch
n_rollout_threads = self.actions.shape[1]
assert n_rollout_threads >= actor_num_mini_batch, (
f"The number of processes ({n_rollout_threads}) "
f"has to be greater than or equal to the number of "
f"mini batches ({actor_num_mini_batch})."
)
num_envs_per_batch = n_rollout_threads // actor_num_mini_batch
# shuffle indices
perm = torch.randperm(n_rollout_threads).numpy()
T, N = self.episode_length, num_envs_per_batch
# prepare data for each mini batch
for batch_id in range(actor_num_mini_batch):
start_id = batch_id * num_envs_per_batch
ids = perm[start_id : start_id + num_envs_per_batch]
obs_batch = _flatten(T, N, self.obs[:-1, ids])
actions_batch = _flatten(T, N, self.actions[:, ids])
masks_batch = _flatten(T, N, self.masks[:-1, ids])
active_masks_batch = _flatten(T, N, self.active_masks[:-1, ids])
old_action_log_probs_batch = _flatten(T, N, self.action_log_probs[:, ids])
adv_targ = _flatten(T, N, advantages[:, ids])
if self.available_actions is not None:
available_actions_batch = _flatten(T, N, self.available_actions[:-1, ids])
else:
available_actions_batch = None
if self.factor is not None:
factor_batch = _flatten(T, N, self.factor[:, ids])
            rnn_states_batch = self.rnn_states[0, ids]  # only the initial rnn states of each trajectory are needed
if self.factor is not None:
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch
else:
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch
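        # A consumption sketch (hypothetical names): each mini batch covers whole
        # trajectories from num_envs_per_batch threads, so rnn_states_batch holds
        # the step-0 hidden states of the sampled threads:
        #
        #     for batch in buffer.naive_recurrent_generator_actor(advantages, actor_num_mini_batch=2):
        #         obs_b, rnn_b, act_b, masks_b, active_b, old_logp_b, adv_b, avail_b = batch
        #         # obs_b has shape (episode_length * num_envs_per_batch, *obs_shape)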
def recurrent_generator_actor(self, advantages, actor_num_mini_batch, data_chunk_length):
"""Training data generator for actor that uses RNN network.
This generator splits the trajectories into chunks of length data_chunk_length,
        and therefore may be more efficient than naive_recurrent_generator_actor in training.
"""
# get episode_length, n_rollout_threads, and mini_batch_size
episode_length, n_rollout_threads = self.actions.shape[0:2]
batch_size = n_rollout_threads * episode_length
data_chunks = batch_size // data_chunk_length
mini_batch_size = data_chunks // actor_num_mini_batch
assert episode_length % data_chunk_length == 0, (
f"episode length ({episode_length}) must be a multiple of data chunk length ({data_chunk_length})."
)
        assert data_chunks >= 2, (
            f"Need at least 2 data chunks, but got {data_chunks}; "
            f"increase the batch size or decrease data_chunk_length."
        )
# shuffle indices
rand = torch.randperm(data_chunks).numpy()
sampler = [
rand[i * mini_batch_size : (i + 1) * mini_batch_size]
for i in range(actor_num_mini_batch)
]
# The following data operations first transpose the first two dimensions of the data (episode_length, n_rollout_threads)
# to (n_rollout_threads, episode_length), then reshape the data to (n_rollout_threads * episode_length, *dim).
# Take obs shape as an example:
# (episode_length + 1, n_rollout_threads, *obs_shape) --> (episode_length, n_rollout_threads, *obs_shape)
# --> (n_rollout_threads, episode_length, *obs_shape) --> (n_rollout_threads * episode_length, *obs_shape)
if len(self.obs.shape) > 3:
obs = self.obs[:-1].transpose(1, 0, 2, 3, 4).reshape(-1, *self.obs.shape[2:])
else:
obs = _sa_cast(self.obs[:-1])
actions = _sa_cast(self.actions)
action_log_probs = _sa_cast(self.action_log_probs)
advantages = _sa_cast(advantages)
masks = _sa_cast(self.masks[:-1])
active_masks = _sa_cast(self.active_masks[:-1])
if self.factor is not None:
factor = _sa_cast(self.factor)
rnn_states = (
self.rnn_states[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states.shape[2:])
)
if self.available_actions is not None:
available_actions = _sa_cast(self.available_actions[:-1])
# generate mini-batches
for indices in sampler:
obs_batch = []
rnn_states_batch = []
actions_batch = []
available_actions_batch = []
masks_batch = []
active_masks_batch = []
old_action_log_probs_batch = []
adv_targ = []
factor_batch = []
for index in indices:
ind = index * data_chunk_length
obs_batch.append(obs[ind : ind + data_chunk_length])
actions_batch.append(actions[ind : ind + data_chunk_length])
if self.available_actions is not None:
available_actions_batch.append(available_actions[ind : ind + data_chunk_length])
masks_batch.append(masks[ind : ind + data_chunk_length])
active_masks_batch.append(active_masks[ind : ind + data_chunk_length])
old_action_log_probs_batch.append(action_log_probs[ind : ind + data_chunk_length])
adv_targ.append(advantages[ind : ind + data_chunk_length])
rnn_states_batch.append(rnn_states[ind]) # only the beginning rnn states are needed
if self.factor is not None:
factor_batch.append(factor[ind : ind + data_chunk_length])
L, N = data_chunk_length, mini_batch_size
# These are all ndarrays of size (data_chunk_length, mini_batch_size, *dim)
obs_batch = np.stack(obs_batch, axis=1)
actions_batch = np.stack(actions_batch, axis=1)
if self.available_actions is not None:
available_actions_batch = np.stack(available_actions_batch, axis=1)
if self.factor is not None:
factor_batch = np.stack(factor_batch, axis=1)
masks_batch = np.stack(masks_batch, axis=1)
active_masks_batch = np.stack(active_masks_batch, axis=1)
old_action_log_probs_batch = np.stack(old_action_log_probs_batch, axis=1)
adv_targ = np.stack(adv_targ, axis=1)
# rnn_states_batch is a (mini_batch_size, *dim) ndarray
rnn_states_batch = np.stack(rnn_states_batch).reshape(N, *self.rnn_states.shape[2:])
# flatten the (data_chunk_length, mini_batch_size, *dim) ndarrays to (data_chunk_length * mini_batch_size, *dim)
obs_batch = _flatten(L, N, obs_batch)
actions_batch = _flatten(L, N, actions_batch)
if self.available_actions is not None:
available_actions_batch = _flatten(L, N, available_actions_batch)
else:
available_actions_batch = None
if self.factor is not None:
factor_batch = _flatten(L, N, factor_batch)
masks_batch = _flatten(L, N, masks_batch)
active_masks_batch = _flatten(L, N, active_masks_batch)
old_action_log_probs_batch = _flatten(L, N, old_action_log_probs_batch)
adv_targ = _flatten(L, N, adv_targ)
if self.factor is not None:
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch
else:
yield obs_batch, rnn_states_batch, actions_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch
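

# A minimal end-to-end smoke test, kept as a sketch under assumptions (gym is
# installed; Box/Discrete spaces are supported by the shape helpers). It is not
# part of the upstream HARL module.
if __name__ == "__main__":
    from gym import spaces

    args = {
        "episode_length": 20,
        "n_rollout_threads": 4,
        "hidden_sizes": [64, 64],
        "recurrent_n": 1,
    }
    buffer = OnPolicyActorBuffer(
        args, spaces.Box(low=-1.0, high=1.0, shape=(8,)), spaces.Discrete(5)
    )
    # Fill one episode with random transitions.
    for _ in range(args["episode_length"]):
        buffer.insert(
            np.random.rand(4, 8).astype(np.float32),  # obs
            np.zeros((4, 1, 64), dtype=np.float32),  # rnn_states
            np.random.randint(5, size=(4, 1)).astype(np.float32),  # actions
            np.zeros((4, 1), dtype=np.float32),  # action_log_probs
            np.ones((4, 1), dtype=np.float32),  # masks
        )
    # Dummy advantages with leading dims (episode_length, n_rollout_threads).
    advantages = np.zeros((20, 4, 1), dtype=np.float32)
    for batch in buffer.feed_forward_generator_actor(advantages, actor_num_mini_batch=2):
        obs_batch = batch[0]
        assert obs_batch.shape == (40, 8)  # mini_batch_size = 80 // 2 = 40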