def forward(self, x, hxs, masks):
    if x.size(0) == hxs.size(0):
        # Single-step case: x and hxs share the batch dimension N, so run the
        # RNN for one step, zeroing the hidden state wherever masks signal an
        # episode reset.
        x, hxs = self.rnn(
            x.unsqueeze(0),
            (hxs * masks.repeat(1, self.recurrent_n).unsqueeze(-1))
            .transpose(0, 1)
            .contiguous(),
        )
        x = x.squeeze(0)
        hxs = hxs.transpose(0, 1)
    else:
        # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
        N = hxs.size(0)
        T = int(x.size(0) / N)

        # unflatten
        x = x.view(T, N, x.size(1))

        # Same deal with masks
        masks = masks.view(T, N)

        # Let's figure out which steps in the sequence have a zero for any agent
        # We will always assume t=0 has a zero in it as that makes the logic cleaner
        has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()

        # +1 to correct the masks[1:]
        if has_zeros.dim() == 0:
            # Deal with scalar
            has_zeros = [has_zeros.item() + 1]
        else:
            has_zeros = (has_zeros + 1).numpy().tolist()

        # add t=0 and t=T to the list
        has_zeros = [0] + has_zeros + [T]

        hxs = hxs.transpose(0, 1)
        outputs = []
        for i in range(len(has_zeros) - 1):
            # We can now process steps that don't have any zeros in masks together!
            # This is much faster
            start_idx = has_zeros[i]
            end_idx = has_zeros[i + 1]
            temp = (
                hxs * masks[start_idx].view(1, -1, 1).repeat(self.recurrent_n, 1, 1)
            ).contiguous()
            rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)
            outputs.append(rnn_scores)

        # assert len(outputs) == T
        # x is a (T, N, -1) tensor
        x = torch.cat(outputs, dim=0)

        # flatten
        x = x.reshape(T * N, -1)
        hxs = hxs.transpose(0, 1)

    x = self.norm(x)
    return x, hxs
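The sequence branch above avoids stepping the RNN one timestep at a time: it only cuts the (T, N, -1) input at timesteps where some environment's mask is zero (an episode reset), and feeds each zero-free chunk to the RNN in one call. The following standalone sketch reproduces just that chunking logic with toy mask values (the T=6, N=2 masks are made up for illustration) so the resulting chunk boundaries can be inspected directly.

import torch

# Toy masks for T=6 steps and N=2 parallel environments; a 0 marks the first
# step of a new episode, where the hidden state must be reset.
masks = torch.tensor([
    [1.0, 1.0],
    [1.0, 1.0],
    [0.0, 1.0],   # env 0 starts a new episode at t=2
    [1.0, 1.0],
    [1.0, 0.0],   # env 1 starts a new episode at t=4
    [1.0, 1.0],
])
T = masks.size(0)

# Same chunking logic as in forward(): find every step after t=0 where any
# environment has a zero mask, and cut the sequence there.
has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()
if has_zeros.dim() == 0:
    has_zeros = [has_zeros.item() + 1]
else:
    has_zeros = (has_zeros + 1).numpy().tolist()
has_zeros = [0] + has_zeros + [T]

print(has_zeros)  # [0, 2, 4, 6]
# The RNN is then run on the chunks x[0:2], x[2:4], x[4:6], resetting the
# hidden state only at these chunk boundaries instead of at every timestep.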