torch.distributions 详解

distributions包含可参数化的概率分布和采样函数,这允许构造用于优化的随机计算图和随机梯度估计器。

通常,不可能直接通过随机样本反向传播。但是,有两种方法可以创建可以反向传播的代理函数,即得分函数估计器/似然比函数估计器/REINFORCE和pathwise derivative估计器。REINFORCE通常被视为强化学习中策略梯度方法的基础,并且pathwise derivative估计器常见于变分自动编码器中的重新参数化技巧。得分函数仅需要样本的值 f ( x ) f(x) f(x),pathwise derivative需要导数 f ′ ( x ) f'(x) f(x),接下来的部分将在一个强化学习示例中讨论这两个问题。

得分函数

当概率密度函数相对于其参数可微分时,我们只需要sample()log_prob()来实现REINFORCE:

△ θ = α r ▽ θ l o g π θ ( a ∣ s ) \triangle\theta = \alpha r \triangledown_{\theta} \mathrm{log} \pi_{\theta}(a | s) θ=αrθlogπθ(as)

其中, θ \theta θ是参数, α \alpha α是学习速率, r r r是奖励, π θ ( a ∣ s ) \pi_{\theta}(a | s) πθ(as)是在给定策略 π θ \pi_{\theta} πθ下在状态 s s s执行动作 a a a的概率。

在实践中,我们将从网络输出中采样一个动作,将这个动作应用到环境中,然后使用log_prob构造一个等效的损失函数。请注意,我们使用负数是因为优化器使用梯度下降,而上面的规则假设梯度上升。

有了确定的策略,REINFORCE的实现代码如下:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

Pathwise derivative

实现这些随机/策略梯度的另一种方法是使用来自rsample()方法的重新参数化技巧,其中参数化随机变量可以通过无参数随机变量的参数确定性函数构造。因此,重新参数化的样本变得可微分,实现Pathwise derivative的代码如下:

param  = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action) # Assuming that reward is differentiable
loss = -reward
loss.backward()

Normal

class torch.distributions.normal.Normal(loc, scale, validate_args=None)

基类:torch.distributions.exp_family.ExponentialFamily

创建由locscale参数化的正态分布(高斯分布)。

e.g.

>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample() # normally distributed with loc=0 and scale=1
tensor([0.1046])

参数:

  • loc (float or Tensor) — 均值(也被称为mu)
  • scale (float or Tensor) — 标准差 (也被称为sigma)

Reference:
https://github.com/apachecn/pytorch-doc-zh/blob/master/docs/1.0/distributions.md

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值