Source code for harl.models.base.rnn

"""RNN modules."""
import torch
import torch.nn as nn

from harl.utils.models_tools import get_init_method


class RNNLayer(nn.Module):
    """GRU stack followed by layer normalization."""

    def __init__(self, inputs_dim, outputs_dim, recurrent_n, initialization_method):
        super(RNNLayer, self).__init__()
        self.recurrent_n = recurrent_n
        self.initialization_method = initialization_method
        self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self.recurrent_n)
        # Zero the biases and apply the chosen initializer to the weights.
        for name, param in self.rnn.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                init_method = get_init_method(initialization_method)
                init_method(param)
        self.norm = nn.LayerNorm(outputs_dim)
    def forward(self, x, hxs, masks):
        if x.size(0) == hxs.size(0):
            # Single-step case: one observation per environment.
            x, hxs = self.rnn(
                x.unsqueeze(0),
                (hxs * masks.repeat(1, self.recurrent_n).unsqueeze(-1))
                .transpose(0, 1)
                .contiguous(),
            )
            x = x.squeeze(0)
            hxs = hxs.transpose(0, 1)
        else:
            # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
            N = hxs.size(0)
            T = int(x.size(0) / N)
            # unflatten
            x = x.view(T, N, x.size(1))
            # Same deal with masks
            masks = masks.view(T, N)

            # Let's figure out which steps in the sequence have a zero for any agent.
            # We will always assume t=0 has a zero in it as that makes the logic cleaner.
            has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()

            # +1 to correct for the masks[1:] offset
            if has_zeros.dim() == 0:
                # Deal with the scalar case
                has_zeros = [has_zeros.item() + 1]
            else:
                has_zeros = (has_zeros + 1).numpy().tolist()

            # add t=0 and t=T to the list
            has_zeros = [0] + has_zeros + [T]

            hxs = hxs.transpose(0, 1)
            outputs = []
            for i in range(len(has_zeros) - 1):
                # We can now process steps that don't have any zeros in masks together!
                # This is much faster.
                start_idx = has_zeros[i]
                end_idx = has_zeros[i + 1]
                temp = (
                    hxs * masks[start_idx].view(1, -1, 1).repeat(self.recurrent_n, 1, 1)
                ).contiguous()
                rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)
                outputs.append(rnn_scores)

            # assert len(outputs) == T
            # x is a (T, N, -1) tensor
            x = torch.cat(outputs, dim=0)

            # flatten
            x = x.reshape(T * N, -1)
            hxs = hxs.transpose(0, 1)

        x = self.norm(x)
        return x, hxs
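Below is a minimal usage sketch (not part of the source file) exercising both branches of forward: the single-step rollout path, where x and hxs share a batch dimension, and the flattened (T * N, -1) training path, where zeros in masks mark episode boundaries and trigger hidden-state resets. The dimensions are arbitrary, and the "orthogonal_" initializer string is an assumption based on HARL's config conventions.

# --- usage sketch (illustrative, not part of the module) ---
import torch

from harl.models.base.rnn import RNNLayer

N, T, D, H = 4, 8, 10, 64  # envs, chunk length, input dim, hidden dim (arbitrary)
layer = RNNLayer(
    inputs_dim=D, outputs_dim=H, recurrent_n=1, initialization_method="orthogonal_"
)

# Path 1: single-step rollout, so x.size(0) == hxs.size(0).
x = torch.randn(N, D)
hxs = torch.zeros(N, 1, H)  # (batch, recurrent_n, hidden)
masks = torch.ones(N, 1)    # 1 = episode continues, 0 = episode just reset
out, hxs = layer(x, hxs, masks)
print(out.shape, hxs.shape)  # torch.Size([4, 64]) torch.Size([4, 1, 64])

# Path 2: a chunk of T steps flattened to (T * N, D). A zero mask at t=2
# makes forward() split the sequence there and zero the hidden state, so
# no information leaks across the episode boundary.
x_seq = torch.randn(T * N, D)
masks_seq = torch.ones(T * N, 1)
masks_seq[2 * N : 3 * N] = 0.0  # every env resets at t=2
out_seq, hxs_out = layer(x_seq, torch.zeros(N, 1, H), masks_seq)
print(out_seq.shape, hxs_out.shape)  # torch.Size([32, 64]) torch.Size([4, 1, 64])

In path 2, forward internally computes has_zeros = [0, 2, 8], so the GRU runs once over steps 0-1 with the given hidden state and once over steps 2-7 with a zeroed one.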