class ValueNorm(nn.Module):
    """Normalize a vector of observations across the first norm_axes dimensions."""

    def __init__(
        self,
        input_shape,
        norm_axes=1,
        beta=0.99999,
        per_element_update=False,
        epsilon=1e-5,
        device=torch.device("cpu"),
    ):
        super(ValueNorm, self).__init__()

        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update
        self.tpdv = dict(dtype=torch.float32, device=device)

        self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv)
    def running_mean_var(self):
        """Get running mean and variance."""
        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
        return debiased_mean, debiased_var
    def denormalize(self, input_vector):
        """Transform normalized data back into the original distribution."""
        if isinstance(input_vector, np.ndarray):
            input_vector = torch.from_numpy(input_vector)
        input_vector = input_vector.to(**self.tpdv)

        mean, var = self.running_mean_var()
        # (None,) * norm_axes prepends singleton dimensions so that mean and
        # std broadcast over the leading norm_axes dimensions of input_vector.
        out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
        out = out.cpu().numpy()

        return out
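
# The update rule for the running statistics is not part of this excerpt.
# What follows is a minimal sketch (a hypothetical helper, not the documented
# API) of how such debiased statistics are typically maintained: running_mean
# and running_mean_sq hold exponentially weighted averages of batch statistics,
# and debiasing_term holds the corresponding sum of weights, so dividing by it
# in running_mean_var() corrects the bias of early estimates, in the spirit of
# Adam-style bias correction.  It assumes numpy (np) and torch are imported as
# in the surrounding module.
def _ema_update_sketch(value_norm, input_vector):
    """Illustrative only: one exponential-moving-average update step."""
    if isinstance(input_vector, np.ndarray):
        input_vector = torch.from_numpy(input_vector)
    input_vector = input_vector.to(**value_norm.tpdv)

    # The first norm_axes dimensions are treated as batch-like; the remaining
    # dimensions must match input_shape for the in-place updates below.
    batch_mean = input_vector.mean(dim=tuple(range(value_norm.norm_axes)))
    batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(value_norm.norm_axes)))

    # per_element_update would presumably make the decay depend on batch size
    # (e.g. weight = beta ** batch_size); the plain case just uses beta.
    weight = value_norm.beta

    value_norm.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
    value_norm.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
    value_norm.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))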
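
# Usage sketch (illustrative only), using just the methods shown above: values
# are normalized by hand with the current debiased statistics, then mapped back
# to the original scale with denormalize(), which inverts that transform.
if __name__ == "__main__":
    value_norm = ValueNorm(input_shape=1)

    # e.g. (timesteps, agents, 1) value targets
    returns = np.random.randn(10, 8, 1).astype(np.float32)

    # Normalize by hand with the current (debiased) statistics ...
    mean, var = value_norm.running_mean_var()
    normalized = (torch.from_numpy(returns) - mean[(None,) * value_norm.norm_axes]) / torch.sqrt(var)[
        (None,) * value_norm.norm_axes
    ]

    # ... and map back to the original scale; the round trip recovers the input.
    recovered = value_norm.denormalize(normalized)
    assert np.allclose(recovered, returns, atol=1e-4)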