Source code for harl.algorithms.critics.v_critic

"""V Critic."""
import torch
import torch.nn as nn
from harl.utils.models_tools import (
    get_grad_norm,
    huber_loss,
    mse_loss,
    update_linear_schedule,
)
from harl.utils.envs_tools import check
from harl.models.value_function_models.v_net import VNet


class VCritic:
    """V Critic.

    Critic that learns a V-function.
    """

    def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
        self.args = args
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.clip_param = args["clip_param"]
        self.critic_epoch = args["critic_epoch"]
        self.critic_num_mini_batch = args["critic_num_mini_batch"]
        self.data_chunk_length = args["data_chunk_length"]
        self.value_loss_coef = args["value_loss_coef"]
        self.max_grad_norm = args["max_grad_norm"]
        self.huber_delta = args["huber_delta"]
        self.use_recurrent_policy = args["use_recurrent_policy"]
        self.use_naive_recurrent_policy = args["use_naive_recurrent_policy"]
        self.use_max_grad_norm = args["use_max_grad_norm"]
        self.use_clipped_value_loss = args["use_clipped_value_loss"]
        self.use_huber_loss = args["use_huber_loss"]
        self.use_policy_active_masks = args["use_policy_active_masks"]
        self.critic_lr = args["critic_lr"]
        self.opti_eps = args["opti_eps"]
        self.weight_decay = args["weight_decay"]

        self.share_obs_space = cent_obs_space
        self.critic = VNet(args, self.share_obs_space, self.device)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(),
            lr=self.critic_lr,
            eps=self.opti_eps,
            weight_decay=self.weight_decay,
        )

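    # The constructor reads its hyperparameters from a flat ``args`` dict. A minimal
    # sketch of the critic-related keys it consumes is given below; the values are
    # illustrative assumptions (not HARL defaults), and ``args`` must additionally
    # contain whatever model keys VNet requires (network sizes, activation,
    # recurrence settings, etc.):
    #
    #     critic_args = {
    #         "clip_param": 0.2,
    #         "critic_epoch": 5,
    #         "critic_num_mini_batch": 1,
    #         "data_chunk_length": 10,
    #         "value_loss_coef": 1.0,
    #         "max_grad_norm": 10.0,
    #         "huber_delta": 10.0,
    #         "use_recurrent_policy": False,
    #         "use_naive_recurrent_policy": False,
    #         "use_max_grad_norm": True,
    #         "use_clipped_value_loss": True,
    #         "use_huber_loss": True,
    #         "use_policy_active_masks": True,
    #         "critic_lr": 5e-4,
    #         "opti_eps": 1e-5,
    #         "weight_decay": 0.0,
    #     }
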
    def lr_decay(self, episode, episodes):
        """Decay the critic learning rate linearly.

        Args:
            episode: (int) current training episode.
            episodes: (int) total number of training episodes.
        """
        update_linear_schedule(
            self.critic_optimizer, episode, episodes, self.critic_lr
        )

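    # Note: assuming ``update_linear_schedule`` implements the usual linear decay
    # (an assumption about the helper in harl.utils.models_tools, not something
    # visible in this file), the learning rate after ``episode`` of ``episodes``
    # total episodes is roughly
    #     lr = critic_lr * (1 - episode / episodes),
    # i.e. it anneals from ``critic_lr`` towards 0 over the course of training.
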
    def get_values(self, cent_obs, rnn_states_critic, masks):
        """Get value function predictions.

        Args:
            cent_obs: (np.ndarray) centralized input to the critic.
            rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
            masks: (np.ndarray) denotes points at which RNN states should be reset.
        Returns:
            values: (torch.Tensor) value function predictions.
            rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """
        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)
        return values, rnn_states_critic

    def cal_value_loss(
        self, values, value_preds_batch, return_batch, value_normalizer=None
    ):
        """Calculate value function loss.

        Args:
            values: (torch.Tensor) value function predictions.
            value_preds_batch: (torch.Tensor) "old" value predictions from data batch (used for value clip loss).
            return_batch: (torch.Tensor) reward to go returns.
            value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.
        Returns:
            value_loss: (torch.Tensor) value function loss.
        """
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param
        )
        if value_normalizer is not None:
            value_normalizer.update(return_batch)
            error_clipped = (
                value_normalizer.normalize(return_batch) - value_pred_clipped
            )
            error_original = value_normalizer.normalize(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self.use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self.use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        value_loss = value_loss.mean()

        return value_loss

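    # Worked sketch of the clipped value loss above (toy numbers, no normalizer;
    # an illustration of the code, not output from it):
    #   value_preds_batch = 1.0 (old prediction), values = 1.5 (new prediction),
    #   return_batch = 1.1, clip_param = 0.2
    #   value_pred_clipped = 1.0 + clamp(1.5 - 1.0, -0.2, 0.2) = 1.2
    #   error_original = 1.1 - 1.5 = -0.4,  error_clipped = 1.1 - 1.2 = -0.1
    # With a squared-error loss the original error gives the larger penalty
    # (0.4**2 = 0.16 vs 0.1**2 = 0.01), and torch.max keeps that larger term, so the
    # critic cannot shrink its loss simply by jumping far away from the old prediction.
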
    def update(self, sample, value_normalizer=None):
        """Update critic network.

        Args:
            sample: (Tuple) contains data batch with which to update networks.
            value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.
        Returns:
            value_loss: (torch.Tensor) value function loss.
            critic_grad_norm: (torch.Tensor) gradient norm from critic update.
        """
        (
            share_obs_batch,
            rnn_states_critic_batch,
            value_preds_batch,
            return_batch,
            masks_batch,
        ) = sample

        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)

        values, _ = self.get_values(
            share_obs_batch, rnn_states_critic_batch, masks_batch
        )

        value_loss = self.cal_value_loss(
            values, value_preds_batch, return_batch, value_normalizer=value_normalizer
        )

        self.critic_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()

        if self.use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(
                self.critic.parameters(), self.max_grad_norm
            )
        else:
            critic_grad_norm = get_grad_norm(self.critic.parameters())

        self.critic_optimizer.step()

        return value_loss, critic_grad_norm

    def train(self, critic_buffer, value_normalizer=None):
        """Perform a training update using minibatch GD.

        Args:
            critic_buffer: (OnPolicyCriticBufferEP or OnPolicyCriticBufferFP) buffer containing training data related to critic.
            value_normalizer: (ValueNorm) normalize the rewards, denormalize critic outputs.
        Returns:
            train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        train_info = {}
        train_info["value_loss"] = 0
        train_info["critic_grad_norm"] = 0

        for _ in range(self.critic_epoch):
            if self.use_recurrent_policy:
                data_generator = critic_buffer.recurrent_generator_critic(
                    self.critic_num_mini_batch, self.data_chunk_length
                )
            elif self.use_naive_recurrent_policy:
                data_generator = critic_buffer.naive_recurrent_generator_critic(
                    self.critic_num_mini_batch
                )
            else:
                data_generator = critic_buffer.feed_forward_generator_critic(
                    self.critic_num_mini_batch
                )

            for sample in data_generator:
                value_loss, critic_grad_norm = self.update(
                    sample, value_normalizer=value_normalizer
                )

                train_info["value_loss"] += value_loss.item()
                train_info["critic_grad_norm"] += critic_grad_norm

        num_updates = self.critic_epoch * self.critic_num_mini_batch

        for k, _ in train_info.items():
            train_info[k] /= num_updates

        return train_info

    def prep_training(self):
        """Prepare for training."""
        self.critic.train()

    def prep_rollout(self):
        """Prepare for rollout."""
        self.critic.eval()
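
# Minimal usage sketch (hypothetical, not part of the original module): it shows the
# order in which an on-policy runner would typically drive this class once per
# episode. ``critic`` is an already constructed VCritic and ``critic_buffer`` is a
# filled OnPolicyCriticBufferEP/FP as mentioned in the train() docstring; the exact
# runner logic in HARL may differ.
def _example_critic_update(critic, critic_buffer, value_normalizer, episode, episodes):
    """Hypothetical per-episode driver for VCritic (illustrative only)."""
    critic.lr_decay(episode, episodes)  # optional linear lr schedule
    critic.prep_training()  # put VNet into train mode
    train_info = critic.train(critic_buffer, value_normalizer=value_normalizer)
    critic.prep_rollout()  # back to eval mode for data collection
    return train_info  # {"value_loss": ..., "critic_grad_norm": ...}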