class VNet(nn.Module):
    """V Network. Outputs value function predictions given global states."""

    def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
        """Initialize VNet model.

        Args:
            args: (dict) arguments containing relevant model information.
                Keys read directly here: "hidden_sizes", "initialization_method",
                "use_naive_recurrent_policy", "use_recurrent_policy",
                "recurrent_n". The whole dict is also passed on to the
                feature-extractor base (CNNBase / MLPBase).
            cent_obs_space: (gym.Space) centralized observation space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super().__init__()
        self.hidden_sizes = args["hidden_sizes"]
        self.initialization_method = args["initialization_method"]
        self.use_naive_recurrent_policy = args["use_naive_recurrent_policy"]
        self.use_recurrent_policy = args["use_recurrent_policy"]
        self.recurrent_n = args["recurrent_n"]
        # dtype/device kwargs applied to every input tensor in forward().
        self.tpdv = dict(dtype=torch.float32, device=device)
        init_method = get_init_method(self.initialization_method)

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        # Observation shapes with exactly 3 dimensions (presumably image-like,
        # e.g. C x H x W -- TODO confirm) use the CNN base; all others use the MLP base.
        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase
        self.base = base(args, cent_obs_shape)

        # The RNN layer is only created when some recurrent mode is enabled;
        # forward() checks the same two flags before using it.
        if self.use_naive_recurrent_policy or self.use_recurrent_policy:
            self.rnn = RNNLayer(
                self.hidden_sizes[-1],
                self.hidden_sizes[-1],
                self.recurrent_n,
                self.initialization_method,
            )

        def init_(m):
            # The lambda sets a tensor to 0; presumably `init` applies it to the
            # layer's bias and `init_method` to its weights -- verify against `init`.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        # Single scalar value head on top of the last hidden layer's features.
        self.v_out = init_(nn.Linear(self.hidden_sizes[-1], 1))

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks):
        """Compute value function predictions from the given inputs.

        Args:
            cent_obs: (np.ndarray / torch.Tensor) centralized observation
                inputs into network.
            rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden
                states for RNN.
            masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN
                states should be reinitialized to zeros.

        Returns:
            values: (torch.Tensor) value function predictions.
            rnn_states: (torch.Tensor) updated RNN hidden states; returned
                unchanged (after dtype/device conversion) when no recurrent
                mode is enabled.
        """
        # Convert every input to a float32 tensor on the model's device.
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_features = self.base(cent_obs)
        if self.use_naive_recurrent_policy or self.use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)
        values = self.v_out(critic_features)

        return values, rnn_states