"""TRPO utility functions."""
import torch
[docs]
def flat_grad(grads):
    """Flatten a sequence of gradients into a single 1-D tensor.

    Args:
        grads: iterable of gradient tensors. ``None`` entries are skipped
            (``torch.autograd.grad(..., allow_unused=True)`` yields ``None``
            for inputs that were not used).

    Returns:
        A 1-D tensor containing every non-``None`` gradient, flattened and
        concatenated in iteration order.

    Note:
        ``torch.cat`` raises if every entry is ``None`` (nothing to join).
    """
    # reshape (instead of view) also accepts non-contiguous gradients and is
    # still tracked by autograd, so second-order use (create_graph=True) works.
    return torch.cat([grad.reshape(-1) for grad in grads if grad is not None])
[docs]
def flat_hessian(hessians):
    """Flatten a sequence of Hessian(-vector) pieces into one 1-D tensor.

    Args:
        hessians: iterable of tensors; ``None`` entries are skipped.

    Returns:
        The ``.data`` of the concatenated, flattened pieces, i.e. a 1-D tensor
        that is not attached to the autograd graph.
    """
    # contiguous() first, so that view(-1) is valid for any memory layout.
    pieces = [
        piece.contiguous().view(-1) for piece in hessians if piece is not None
    ]
    return torch.cat(pieces).data
[docs]
def flat_params(model):
    """Return all parameters of ``model`` as a single 1-D tensor.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A 1-D tensor built by concatenating the flattened ``.data`` of each
        parameter, in the order ``model.parameters()`` yields them.
    """
    return torch.cat([weight.data.view(-1) for weight in model.parameters()])
[docs]
def update_model(model, new_params):
    """Overwrite the parameters of ``model`` in place from a flat vector.

    The vector is consumed front to back: each parameter, in the order given
    by ``model.parameters()``, takes the next ``numel()`` entries, reshaped to
    that parameter's shape. This is the inverse layout of ``flat_params``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        new_params: 1-D tensor holding the new values for every parameter.
            Entries beyond the total parameter count are ignored; a vector
            that is too short makes the reshape fail.

    Returns:
        None. The parameters are modified in place through ``.data``.
    """
    offset = 0
    for param in model.parameters():
        # numel() counts elements for any memory layout; the former
        # len(param.view(-1)) raised on non-contiguous parameters.
        count = param.numel()
        chunk = new_params[offset : offset + count]
        param.data.copy_(chunk.reshape(param.size()))
        offset += count
[docs]
def kl_approx(p, q):
    """Element-wise KL approximation ``exp(q - p) - 1 - q + p``.

    Within this file it is called with the logits of the old (``p``) and new
    (``q``) action distributions. The expression is zero wherever ``p == q``.

    Args:
        p: tensor of reference values.
        q: tensor of values compared against ``p``; must broadcast with it.

    Returns:
        Tensor with the broadcast shape of ``p`` and ``q``.
    """
    return torch.exp(q - p) - 1 - q + p
def _kl_normal_normal(p, q):
    """Closed-form KL(p || q) for two normal distributions, in float64.

    adapted from https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence

    Args:
        p: distribution object exposing ``loc`` and ``scale`` tensors.
        q: distribution object exposing ``loc`` and ``scale`` tensors.

    Returns:
        Element-wise KL divergence as a ``torch.float64`` tensor.
    """
    # Cast once up front; every term below is evaluated in double precision.
    p_loc = p.loc.to(torch.float64)
    p_scale = p.scale.to(torch.float64)
    q_loc = q.loc.to(torch.float64)
    q_scale = q.scale.to(torch.float64)

    scale_ratio_sq = (p_scale / q_scale).pow(2)
    mean_shift_sq = ((p_loc - q_loc) / q_scale).pow(2)
    return 0.5 * (scale_ratio_sq + mean_shift_sq - 1 - scale_ratio_sq.log())
[docs]
def kl_divergence(
    obs,
    rnn_states,
    action,
    masks,
    available_actions,
    active_masks,
    new_actor,
    old_actor,
):
    """KL divergence between the action distributions of two actors.

    Both actors are evaluated on the same inputs through
    ``evaluate_actions``; the third element of each result is taken as the
    action distribution. The old actor is evaluated under ``torch.no_grad()``,
    so gradients only flow through the new actor.

    Args:
        obs, rnn_states, action, masks, available_actions, active_masks:
            forwarded unchanged to ``evaluate_actions`` of both actors.
        new_actor: actor whose distribution is differentiated.
        old_actor: reference actor, evaluated without gradient tracking.

    Returns:
        The KL tensor. If it has more than one dimension it is summed over
        dimension 1 with ``keepdim=True``.
    """
    eval_args = (obs, rnn_states, action, masks, available_actions, active_masks)

    _, _, new_dist = new_actor.evaluate_actions(*eval_args)
    with torch.no_grad():
        _, _, old_dist = old_actor.evaluate_actions(*eval_args)

    # The distribution class is matched by name: "FixedCategorical" marks the
    # discrete-action case, everything else is treated as continuous (normal).
    if type(new_dist).__name__ == "FixedCategorical":
        kl = kl_approx(old_dist.logits, new_dist.logits)
    else:
        kl = _kl_normal_normal(old_dist, new_dist)

    if kl.dim() > 1:
        kl = kl.sum(1, keepdim=True)
    return kl
# pylint: disable-next=invalid-name
[docs]
def conjugate_gradient(
    actor,
    obs,
    rnn_states,
    action,
    masks,
    available_actions,
    active_masks,
    b,
    nsteps,
    device,
    residual_tol=1e-10,
):
    """Approximately solve ``A x = b`` with the conjugate gradient method.

    ``A`` is never formed explicitly; products ``A v`` are obtained from
    ``fisher_vector_product`` evaluated on ``actor`` and the given batch.

    # refer to https://github.com/openai/baselines/blob/master/baselines/common/cg.py

    Args:
        actor, obs, rnn_states, action, masks, available_actions, active_masks:
            forwarded unchanged to ``fisher_vector_product``.
        b: 1-D right-hand-side tensor.
        nsteps: maximum number of iterations.
        device: device on which the solution tensor is allocated.
        residual_tol: iteration stops early once the squared residual norm
            falls below this value.

    Returns:
        1-D tensor ``x`` with the same size as ``b``.
    """
    solution = torch.zeros(b.size(), device=device)
    residual = b.clone()
    direction = b.clone()
    res_sq = torch.dot(residual, residual)

    for _ in range(nsteps):
        fvp = fisher_vector_product(
            actor,
            obs,
            rnn_states,
            action,
            masks,
            available_actions,
            active_masks,
            direction,
        )
        step = res_sq / torch.dot(direction, fvp)
        # In-place updates, matching the accumulate-into-buffer semantics.
        solution.add_(step * direction)
        residual.sub_(step * fvp)

        next_res_sq = torch.dot(residual, residual)
        beta = next_res_sq / res_sq
        direction = residual + beta * direction
        res_sq = next_res_sq
        if res_sq < residual_tol:
            break

    return solution
[docs]
def fisher_vector_product(
    actor, obs, rnn_states, action, masks, available_actions, active_masks, p
):
    """Damped Fisher-vector product ``H p + 0.1 * p``.

    ``H`` is the Hessian, with respect to the actor parameters, of the mean
    KL divergence between the actor and itself (``kl_divergence`` is called
    with ``new_actor`` and ``old_actor`` both set to ``actor``). The product
    is obtained by double backpropagation rather than by forming ``H``.

    Args:
        actor: policy whose parameters define the KL Hessian.
        obs, rnn_states, action, masks, available_actions, active_masks:
            forwarded unchanged to ``kl_divergence``.
        p: 1-D tensor to multiply; it is treated as a constant.

    Returns:
        1-D tensor ``H p + 0.1 * p``, detached from the autograd graph.
    """
    # cuDNN is disabled for the whole computation, presumably because its
    # recurrent kernels do not support the double backward pass used below.
    # NOTE(review): confirm against the actor's network types.
    with torch.backends.cudnn.flags(enabled=False):
        # Tensor.detach() is not in-place: the result must be rebound,
        # otherwise the call has no effect and p keeps its autograd history.
        p = p.detach()
        kl = kl_divergence(
            obs,
            rnn_states,
            action,
            masks,
            available_actions,
            active_masks,
            new_actor=actor,
            old_actor=actor,
        )
        kl = kl.mean()
        # First backward pass; create_graph=True keeps it differentiable.
        kl_grad = torch.autograd.grad(
            kl, actor.parameters(), create_graph=True, allow_unused=True
        )
        kl_grad = flat_grad(kl_grad)  # check kl_grad == 0
        # Differentiating (grad . p) w.r.t. the parameters yields H p.
        kl_grad_p = (kl_grad * p).sum()
        kl_hessian_p = torch.autograd.grad(
            kl_grad_p, actor.parameters(), allow_unused=True
        )
        kl_hessian_p = flat_hessian(kl_hessian_p)
        # The 0.1 * p term adds damping to the product.
        return kl_hessian_p + 0.1 * p