def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_avg_policy_vb, detached_splitted_policy_vb=None):
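# first-order trust-region (TRPO-style) update from the ACER paper (Wang et al., 2017):
# projects the policy gradient g onto the trust region defined by the KL divergence
# from the average policy network to the current policy,
#     z* = g - max(0, (k^T g - delta) / ||k||_2^2) * k
# everything below operates on detached copies of the policy output, so the returned
# z* can later be pushed back through the live model output in a separate backward pass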
on_policy = detached_splitted_policy_vb is None
# KL divergence: k = \nabla_{\phi_\theta} D_{KL}[ \pi(\cdot|\phi_{\theta_a}) || \pi(\cdot|\phi_\theta) ]
# kl_div_vb = F.kl_div(detached_policy_vb.log(), detached_avg_policy_vb, size_average=False) # NOTE: the built-in kl_div reduces over the whole batch, so it cannot give the per-sample KL values we need here
kl_div_vb = categorical_kl_div(detached_policy_vb, detached_avg_policy_vb)
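# NOTE: categorical_kl_div is a project helper, not a torch built-in; it is assumed
# NOTE: to mirror the commented-out F.kl_div call but per batch row, i.e. roughly:
# NOTE:     (q * (q.log() - p.log())).sum(1, keepdim=True)
# NOTE: with p = detached_policy_vb, q = detached_avg_policy_vb, so kl_div_vb keeps
# NOTE: one KL value per sample and can be split along dim 0 in the off-policy branch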
# NOTE: k & g are both taken w.r.t. the network output, which is detached_policy_vb
# NOTE: gradients from this part will not flow back into the model,
# NOTE: which is why we only use detached policy variables here
if on_policy:
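    # on-policy: the whole batch lives in one graph, so a single grad() call per
    # output covers all samples at once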
    k_vb = grad(outputs=kl_div_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]
    g_vb = grad(outputs=detached_policy_loss_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]
else:
    # NOTE NOTE NOTE !!!
    # NOTE: here is why we cannot simply detach then split policy_vb, but must split before detaching:
    # NOTE: if we detached first, the split pieces could not trace the grads computed in this
    # NOTE: later part of the graph back, and would have no way to connect to the graphs in the model
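    # off-policy: each replayed sample carries its own graph, so we take the grads
    # sample by sample (hence the split along dim 0) and re-assemble with torch.cat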
    k_vb = grad(outputs=kl_div_vb.split(1, 0), inputs=detached_splitted_policy_vb, retain_graph=False, only_inputs=True)
    g_vb = grad(outputs=detached_policy_loss_vb.split(1, 0), inputs=detached_splitted_policy_vb, retain_graph=False, only_inputs=True)
    k_vb = torch.cat(k_vb, 0)
    g_vb = torch.cat(g_vb, 0)
kg_dot_vb = (k_vb * g_vb).sum(1, keepdim=True)
kk_dot_vb = (k_vb * k_vb).sum(1, keepdim=True)
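# closed-form solution of the linearized trust-region problem (ACER paper):
# z* = g - max(0, (k^T g - delta) / ||k||_2^2) * k
# where delta = self.master.clip_1st_order_trpo; clamp(min=0) implements the max(0, .)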
z_star_vb = g_vb - ((kg_dot_vb - self.master.clip_1st_order_trpo) / kk_dot_vb).clamp(min=0) * k_vb
return z_star_vb