"""Modify standard PyTorch distributions to make them compatible with this codebase."""
import torch
import torch.nn as nn

from harl.utils.models_tools import init, get_init_method
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution type used by this codebase.

    This class body defines no overrides: construction (``probs=`` /
    ``logits=``) and every method are inherited unchanged from
    ``torch.distributions.Categorical``.
    """
class Categorical(nn.Module):
    """Linear layer whose outputs feed a Categorical distribution."""

    def __init__(
        self, num_inputs, num_outputs, initialization_method="orthogonal_", gain=0.01
    ):
        """Build the single linear projection ``self.linear``.

        Args:
            num_inputs: input feature size of the linear layer.
            num_outputs: output size of the linear layer.
            initialization_method: name handed to ``get_init_method`` to pick
                the weight initializer passed on to ``init``.
            gain: gain value forwarded to ``init``.
        """
        super().__init__()
        weight_initializer = get_init_method(initialization_method)

        def fill_with_zeros(tensor):
            # Passed to ``init`` as its third argument; fills the tensor with 0.
            return nn.init.constant_(tensor, 0)

        projection = nn.Linear(num_inputs, num_outputs)
        self.linear = init(projection, weight_initializer, fill_with_zeros, gain)
class DiagGaussian(nn.Module):
    """A linear layer followed by a Diagonal Gaussian distribution.

    The constructor builds ``fc_mean`` (a linear map from ``num_inputs`` to
    ``num_outputs``) and a learnable per-output-dimension parameter
    ``log_std``, and stores the two coefficients ``std_x_coef`` /
    ``std_y_coef``.
    """

    def __init__(
        self,
        num_inputs,
        num_outputs,
        initialization_method="orthogonal_",
        gain=0.01,
        args=None,
    ):
        """Build the mean layer and the learnable spread parameter.

        Args:
            num_inputs: input feature size of ``fc_mean``.
            num_outputs: output size of ``fc_mean`` and length of ``log_std``.
            initialization_method: name handed to ``get_init_method`` to pick
                the weight initializer passed on to ``init``.
            gain: gain value forwarded to ``init``.
            args: optional mapping that may provide ``"std_x_coef"`` and/or
                ``"std_y_coef"``. Each missing key (or ``args=None``) falls
                back to 1.0 and 0.5 respectively.
        """
        super().__init__()
        init_method = get_init_method(initialization_method)

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)

        self.std_x_coef, self.std_y_coef = self._std_coefs(args)
        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        # NOTE(review): despite its name, ``log_std`` is initialized to
        # ``std_x_coef`` itself, not to a logarithm; how it becomes a standard
        # deviation is decided by the forward pass (not visible here) - confirm
        # before relying on the name. The name is kept because it is the
        # parameter's state_dict key.
        self.log_std = nn.Parameter(torch.ones(num_outputs) * self.std_x_coef)

    @staticmethod
    def _std_coefs(args):
        """Return ``(std_x_coef, std_y_coef)`` resolved from ``args``.

        Each coefficient is looked up independently, so a mapping that
        supplies only one key keeps the default (1.0 / 0.5) for the other
        instead of raising ``KeyError``. Assumes ``args`` is a dict-like
        mapping with ``.get`` (NOTE(review): confirm for all callers).
        """
        if args is None:
            return 1.0, 0.5
        return args.get("std_x_coef", 1.0), args.get("std_y_coef", 0.5)