def flat_grad(grads):
    """Concatenate a sequence of gradient tensors into one 1-D tensor.

    ``None`` entries (parameters that received no gradient) are skipped.
    Raises if every entry is ``None`` (``torch.cat`` on an empty list).
    """
    pieces = [g.view(-1) for g in grads if g is not None]
    return torch.cat(pieces)
def flat_hessian(hessians):
    """Concatenate Hessian(-vector) tensors into a single detached 1-D tensor.

    ``None`` entries are skipped; ``.contiguous()`` is needed before
    ``view(-1)`` since second-order autograd outputs may be non-contiguous.
    """
    pieces = [h.contiguous().view(-1) for h in hessians if h is not None]
    return torch.cat(pieces).data
def flat_params(model):
    """Return all parameters of *model* flattened into one 1-D tensor.

    Reads ``.data`` directly, so the result is detached from autograd.
    """
    return torch.cat([p.data.view(-1) for p in model.parameters()])
def update_model(model, new_params):
    """Load a flat parameter vector back into *model* in place.

    *new_params* must contain exactly as many elements as the model has
    parameters, laid out in ``model.parameters()`` iteration order
    (the inverse of :func:`flat_params`).
    """
    offset = 0
    for param in model.parameters():
        count = param.view(-1).size(0)
        chunk = new_params[offset : offset + count].view(param.size())
        param.data.copy_(chunk)
        offset += count
def kl_approx(p, q):
    """Elementwise KL-divergence estimate from log-probabilities *p*, *q*.

    Uses the low-variance estimator ``exp(q - p) - 1 - (q - p)``, which is
    non-negative and zero iff ``p == q`` elementwise.
    """
    ratio = (q - p).exp()
    # Keep the original evaluation order (ratio - 1 - q + p) so results
    # match bit-for-bit in floating point.
    kl = ratio - 1 - q + p
    return kl
def _kl_normal_normal(p, q):
    """Closed-form KL(p || q) for two Normal distributions, in float64.

    Adapted from
    https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
    with all terms promoted to float64 for numerical stability.
    """
    p_loc = p.loc.to(torch.float64)
    p_scale = p.scale.to(torch.float64)
    q_loc = q.loc.to(torch.float64)
    q_scale = q.scale.to(torch.float64)
    var_ratio = (p_scale / q_scale).pow(2)
    t1 = ((p_loc - q_loc) / q_scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
def kl_divergence(
    obs,
    rnn_states,
    action,
    masks,
    available_actions,
    active_masks,
    new_actor,
    old_actor,
):
    """Per-sample KL divergence between old and new actor policies.

    The old actor's distribution is evaluated under ``torch.no_grad()`` so
    gradients only flow through the new actor. Discrete (FixedCategorical)
    policies use the logits-based approximation :func:`kl_approx`; continuous
    policies use the closed-form :func:`_kl_normal_normal`. Multi-dimensional
    KL is summed over dim 1 so the result is one value per sample.
    """
    _, _, new_dist = new_actor.evaluate_actions(
        obs, rnn_states, action, masks, available_actions, active_masks
    )
    with torch.no_grad():
        _, _, old_dist = old_actor.evaluate_actions(
            obs, rnn_states, action, masks, available_actions, active_masks
        )
    if new_dist.__class__.__name__ == "FixedCategorical":
        # discrete action space
        kl = kl_approx(old_dist.logits, new_dist.logits)
    else:
        # continuous action space
        kl = _kl_normal_normal(old_dist, new_dist)
    if len(kl.shape) > 1:
        kl = kl.sum(1, keepdim=True)
    return kl
# pylint: disable-next=invalid-name
def conjugate_gradient(
    actor,
    obs,
    rnn_states,
    action,
    masks,
    available_actions,
    active_masks,
    b,
    nsteps,
    device,
    residual_tol=1e-10,
):
    """Approximately solve ``F x = b`` with conjugate gradients.

    ``F`` is the Fisher information matrix, accessed only through
    ``fisher_vector_product`` (matrix-free). Runs at most *nsteps*
    iterations, stopping early once the squared residual drops below
    *residual_tol*. Based on
    https://github.com/openai/baselines/blob/master/baselines/common/cg.py
    """
    x = torch.zeros(b.size()).to(device=device)
    residual = b.clone()
    direction = b.clone()
    rs_old = torch.dot(residual, residual)
    for _ in range(nsteps):
        fvp = fisher_vector_product(
            actor, obs, rnn_states, action, masks, available_actions, active_masks, direction
        )
        step = rs_old / torch.dot(direction, fvp)
        x += step * direction
        residual -= step * fvp
        rs_new = torch.dot(residual, residual)
        direction = residual + (rs_new / rs_old) * direction
        rs_old = rs_new
        if rs_old < residual_tol:
            break
    return x