TRPO核心代码及注释

TRPO核心代码及注释

TRPO颇难,不如PPO直观,并且其中很多数学知识不容易理解,还需努力
参考网址附在代码注释里边了,感谢各位blog作者的努力

import numpy as np
from hparams import HyperParams as hp
from model import Actor
import torch
import math


def get_action(mu, std):
    ''' 得到一个动作 
        输入是Actor Network的输出(mu,std)
        输出是一个数
    '''
    action = torch.normal(mu, std)
    action = action.data.numpy()
    return action


def log_density(x, mu, std, logstd):
    ''' 
        得到一个概率密度函数的log值
        输入是高斯分布的均值mu和方差std
        还有std的log值(log其实就是ln,以e为底数)
        mu std 定义了高斯分布的概率密度函数F(x)
        给定概率密度函数F(x)里边的x
        就能算出取值在x附近的概率,【一定是附近】
        然后再取log
    '''
    var = std.pow(2)
    # 实际上就是高斯分布的概率密度函数的log值
    log_density = -(x - mu).pow(2) / (2 * var) \
                  - 0.5 * math.log(2 * math.pi) - logstd
    return log_density.sum(1, keepdim=True)


def flat_grad(grads):
    '''
        把一个输入向量grads展平
        展成一维向量
    '''
    grad_flatten = []
    for grad in grads:
        grad_flatten.append(grad.view(-1))
    grad_flatten = torch.cat(grad_flatten)
    return grad_flatten


def flat_hessian(hessians):
    '''
        把一个矩阵展平
        展成一维向量
    '''
    hessians_flatten = []
    for hessian in hessians:
        hessians_flatten.append(hessian.contiguous().view(-1))
    hessians_flatten = torch.cat(hessians_flatten).data
    return hessians_flatten


def flat_params(model):
    '''
        把模型里边的参数展平
    '''
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    params_flatten = torch.cat(params)
    return params_flatten


def update_model(model, new_params):
    index = 0
    for params in model.parameters():
        params_length = len(params.view(-1))
        new_param = new_params[index: index + params_length]
        new_param = new_param.view(params.size())
        params.data.copy_(new_param)
        index += params_length


def kl_divergence(new_actor, old_actor, states):
    '''
        计算新旧两个actor输出策略之间的KL散度
    '''
    # 新的actor
    mu, std, logstd = new_actor(torch.Tensor(states))
    # 旧的actor
    mu_old, std_old, logstd_old = old_actor(torch.Tensor(states))
    # 从计算图剥离
    mu_old 
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值