Source code for harl.algorithms.actors.hatrpo

"""HATRPO algorithm."""

import numpy as np
import torch
from harl.utils.envs_tools import check
from harl.utils.trpo_util import (
    flat_grad,
    flat_params,
    conjugate_gradient,
    fisher_vector_product,
    update_model,
    kl_divergence,
)
from harl.algorithms.actors.on_policy_base import OnPolicyBase
from harl.models.policy_models.stochastic_policy import StochasticPolicy


class HATRPO(OnPolicyBase):
    def __init__(self, args, obs_space, act_space, device=torch.device("cpu")):
        """Initialize HATRPO algorithm.
        Args:
            args: (dict) arguments.
            obs_space: (gym.spaces or list) observation space.
            act_space: (gym.spaces) action space.
            device: (torch.device) device to use for tensor operations.
        """
        assert (
            act_space.__class__.__name__ != "MultiDiscrete"
        ), "only continuous and discrete action spaces are supported by HATRPO."
        super(HATRPO, self).__init__(args, obs_space, act_space, device)

        self.kl_threshold = args["kl_threshold"]
        self.ls_step = args["ls_step"]
        self.accept_ratio = args["accept_ratio"]
        self.backtrack_coeff = args["backtrack_coeff"]
    def update(self, sample):
        """Update actor networks.
        Args:
            sample: (Tuple) contains data batch with which to update networks.
        Returns:
            kl: (torch.Tensor) KL divergence between old and new policy.
            loss_improve: (np.float32) loss improvement.
            expected_improve: (np.ndarray) expected loss improvement.
            dist_entropy: (torch.Tensor) action entropies.
            ratio: (torch.Tensor) ratio between new and old policy.
        """
        (
            obs_batch,
            rnn_states_batch,
            actions_batch,
            masks_batch,
            active_masks_batch,
            old_action_log_probs_batch,
            adv_targ,
            available_actions_batch,
            factor_batch,
        ) = sample

        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)
        factor_batch = check(factor_batch).to(**self.tpdv)

        # Reshape to do evaluations for all steps in a single forward pass
        action_log_probs, dist_entropy, _ = self.evaluate_actions(
            obs_batch,
            rnn_states_batch,
            actions_batch,
            masks_batch,
            available_actions_batch,
            active_masks_batch,
        )

        # actor update
        ratio = getattr(torch, self.action_aggregation)(
            torch.exp(action_log_probs - old_action_log_probs_batch),
            dim=-1,
            keepdim=True,
        )
        if self.use_policy_active_masks:
            loss = (
                torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True)
                * active_masks_batch
            ).sum() / active_masks_batch.sum()
        else:
            loss = torch.sum(
                ratio * factor_batch * adv_targ, dim=-1, keepdim=True
            ).mean()

        loss_grad = torch.autograd.grad(
            loss, self.actor.parameters(), allow_unused=True
        )
        loss_grad = flat_grad(loss_grad)

        # conjugate gradient approximately solves H * step_dir = loss_grad using
        # only Fisher-vector products
        step_dir = conjugate_gradient(
            self.actor,
            obs_batch,
            rnn_states_batch,
            actions_batch,
            masks_batch,
            available_actions_batch,
            active_masks_batch,
            loss_grad.data,
            nsteps=10,
            device=self.device,
        )

        loss = loss.data.cpu().numpy()
        params = flat_params(self.actor)
        fvp = fisher_vector_product(
            self.actor,
            obs_batch,
            rnn_states_batch,
            actions_batch,
            masks_batch,
            available_actions_batch,
            active_masks_batch,
            step_dir,
        )
        # scale step_dir so the quadratic KL estimate 0.5 * s^T H s reaches
        # kl_threshold (see the step-size note after this listing)
        shs = 0.5 * (step_dir * fvp).sum(0, keepdim=True)
        step_size = 1 / torch.sqrt(shs / self.kl_threshold)[0]
        full_step = step_size * step_dir

        # snapshot the current policy for KL measurement and as a fallback
        old_actor = StochasticPolicy(
            self.args, self.obs_space, self.act_space, self.device
        )
        update_model(old_actor, params)
        expected_improve = (loss_grad * full_step).sum(0, keepdim=True)
        expected_improve = expected_improve.data.cpu().numpy()

        # Backtracking line search (https://en.wikipedia.org/wiki/Backtracking_line_search)
        flag = False
        fraction = 1
        for i in range(self.ls_step):
            new_params = params + fraction * full_step
            update_model(self.actor, new_params)
            action_log_probs, dist_entropy, _ = self.evaluate_actions(
                obs_batch,
                rnn_states_batch,
                actions_batch,
                masks_batch,
                available_actions_batch,
                active_masks_batch,
            )

            ratio = getattr(torch, self.action_aggregation)(
                torch.exp(action_log_probs - old_action_log_probs_batch),
                dim=-1,
                keepdim=True,
            )
            if self.use_policy_active_masks:
                new_loss = (
                    torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True)
                    * active_masks_batch
                ).sum() / active_masks_batch.sum()
            else:
                new_loss = torch.sum(
                    ratio * factor_batch * adv_targ, dim=-1, keepdim=True
                ).mean()

            new_loss = new_loss.data.cpu().numpy()
            loss_improve = new_loss - loss

            kl = kl_divergence(
                obs_batch,
                rnn_states_batch,
                actions_batch,
                masks_batch,
                available_actions_batch,
                active_masks_batch,
                new_actor=self.actor,
                old_actor=old_actor,
            )
            kl = kl.mean()

            if (
                kl < self.kl_threshold
                and (loss_improve / expected_improve) > self.accept_ratio
                and loss_improve.item() > 0
            ):
                flag = True
                break
            expected_improve *= self.backtrack_coeff
            fraction *= self.backtrack_coeff

        if not flag:
            # line search failed: fall back to the old parameters
            params = flat_params(old_actor)
            update_model(self.actor, params)
            print("policy update does not improve the surrogate")

        return kl, loss_improve, expected_improve, dist_entropy, ratio
    def train(self, actor_buffer, advantages, state_type):
        """Perform a training update using minibatch GD.
        Args:
            actor_buffer: (OnPolicyActorBuffer) buffer containing training data related to actor.
            advantages: (np.ndarray) advantages.
            state_type: (str) type of state.
        Returns:
            train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        train_info = {}
        train_info["kl"] = 0
        train_info["dist_entropy"] = 0
        train_info["loss_improve"] = 0
        train_info["expected_improve"] = 0
        train_info["ratio"] = 0

        if np.all(actor_buffer.active_masks[:-1] == 0.0):
            return train_info

        if state_type == "EP":
            advantages_copy = advantages.copy()
            advantages_copy[actor_buffer.active_masks[:-1] == 0.0] = np.nan
            mean_advantages = np.nanmean(advantages_copy)
            std_advantages = np.nanstd(advantages_copy)
            advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        if self.use_recurrent_policy:
            data_generator = actor_buffer.recurrent_generator_actor(
                advantages, 1, self.data_chunk_length
            )
        elif self.use_naive_recurrent_policy:
            data_generator = actor_buffer.naive_recurrent_generator_actor(advantages, 1)
        else:
            data_generator = actor_buffer.feed_forward_generator_actor(advantages, 1)

        for sample in data_generator:
            kl, loss_improve, expected_improve, dist_entropy, imp_weights = self.update(
                sample
            )

            train_info["kl"] += kl
            train_info["loss_improve"] += loss_improve.item()
            train_info["expected_improve"] += expected_improve
            train_info["dist_entropy"] += dist_entropy.item()
            train_info["ratio"] += imp_weights.mean()

        num_updates = 1
        for k in train_info.keys():
            train_info[k] /= num_updates

        return train_info
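A few notes follow for readers tracing the update above; they are editorial reading aids, not part of the HARL source. First, the shs and step_size lines in update implement the standard TRPO step-size rule. Reading that code with s = step_dir (the conjugate-gradient solution), H the Fisher matrix applied by fisher_vector_product, g = loss_grad, and delta = kl_threshold, the scaling solves the linearized, KL-constrained problem

    \max_{\beta \ge 0} \; \beta\, g^{\top} s
    \quad \text{subject to} \quad \tfrac{1}{2}\, \beta^{2}\, s^{\top} H s \le \delta ,

whose constraint is tight at the optimum, giving

    \beta = \sqrt{\frac{2\delta}{s^{\top} H s}}
          = \frac{1}{\sqrt{\texttt{shs}/\delta}},
    \qquad \texttt{shs} = \tfrac{1}{2}\, s^{\top} H s .

So full_step = step_size * step_dir saturates the quadratic KL approximation at kl_threshold, and the backtracking line search then only shrinks it if needed.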
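Second, conjugate_gradient and fisher_vector_product together approximate the natural-gradient direction H^{-1} g without ever forming or inverting H: the solver only needs products H v, which HARL obtains by differentiating the KL divergence. The stand-alone NumPy sketch below is illustrative only (cg_solve and the explicit 2x2 matrix are hypothetical stand-ins, not HARL code); it shows the same mechanism on a system small enough to check against a direct solve.

    import numpy as np

    def cg_solve(matvec, g, nsteps=10, residual_tol=1e-10):
        """Solve H x = g using only matrix-vector products with H."""
        x = np.zeros_like(g)
        r = g.copy()          # residual g - H x, with x = 0 initially
        p = r.copy()          # current search direction
        rdotr = r @ r
        for _ in range(nsteps):
            Hp = matvec(p)
            alpha = rdotr / (p @ Hp)
            x = x + alpha * p
            r = r - alpha * Hp
            new_rdotr = r @ r
            if new_rdotr < residual_tol:
                break
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr
        return x

    H = np.array([[3.0, 1.0], [1.0, 2.0]])   # stand-in for the Fisher matrix
    g = np.array([1.0, 1.0])                 # stand-in for the surrogate gradient
    step_dir = cg_solve(lambda v: H @ v, g)
    print(step_dir, np.linalg.solve(H, g))   # the two results should match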
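Finally, the acceptance rule of the backtracking line search can be isolated from the policy machinery. The sketch below (again illustrative NumPy with hypothetical names, not HARL code) applies the same rule used in update to a toy quadratic objective: accept the first scaled step whose actual improvement is positive and exceeds accept_ratio times the expected first-order improvement, shrinking both the fraction and the expectation by backtrack_coeff otherwise. The KL check performed in update is omitted here for brevity.

    import numpy as np

    def backtracking_accept(surrogate, params, full_step, expected_improve,
                            accept_ratio=0.5, backtrack_coeff=0.8, ls_step=10):
        """Return the first scaled step passing the improvement test, or None
        (mirroring the fallback to the old parameters in HATRPO.update)."""
        old_value = surrogate(params)
        fraction = 1.0
        for _ in range(ls_step):
            candidate = params + fraction * full_step
            improve = surrogate(candidate) - old_value
            if improve > 0 and improve / expected_improve > accept_ratio:
                return candidate
            expected_improve *= backtrack_coeff
            fraction *= backtrack_coeff
        return None

    # toy example: maximize -||x||^2, with a proposed full step pointing at the optimum
    surrogate = lambda x: -np.dot(x, x)
    x0 = np.array([1.0, -2.0])
    step = -x0
    expected = np.dot(-2 * x0, step)   # first-order estimate g^T s at x0
    print(backtracking_accept(surrogate, x0, step, expected))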