Reparameterization Trick 及 Inverse Transform Sampling
- Reparameterization Trick 请看《von Mises-Fisher Distribution (代码解析) - 1.3 has_rsample=True》
- Inverse Transform Sampling 请看《von Mises-Fisher Distribution - 5.1 p = 3 时的 Inverse Transform Sampling》, 或者 Inverse transform sampling - Wikipedia.
Distribution 采样对参数的梯度
我们知道 Reparameterization Trick 是用简单的分布配合一些参数转化为更复杂的分布, 如标准正态分布 N ( 0 , 1 ) N(0,1) N(0,1) 配合参数 μ , σ \mu, \sigma μ,σ, 在从 N ( 0 , 1 ) N(0,1) N(0,1) 采样得到样本 x x x 后, 计算 y = x ∗ σ + μ y = x*\sigma + \mu y=x∗σ+μ 就得到对分布 N ( μ , σ ) N(\mu, \sigma) N(μ,σ) 的采样.
指数分布的累积分布函数为 F ( x ) = 1 − e − λ x , x > 0 F(x) = 1 - e^{-\lambda x}, ~ x>0 F(x)=1−e−λx, x>0, 根据逆变换采样方法, u = F ( x ) ∼ U n i f o r m ( 0 , 1 ) u = F(x) \sim Uniform(0,1) u=F(x)∼Uniform(0,1), 从均匀分布中采样一个 u u u, 再计算 x = F − 1 ( u ) = − 1 λ l n ( 1 − u ) x = F^{-1}(u) = -\frac{1}{\lambda}ln(1-u) x=F−1(u)=−λ1ln(1−u) 就得到对指数分布的采样.
无论是简单分布的 Transform, 还是 Inverse Transform, 可以看到, 对于带参数的概率分布, 是可以计算样本 x x x 对参数 ( μ , σ ) (\mu, \sigma) (μ,σ) 或 λ \lambda λ 的梯度的, 以便在模型优化过程中对概率分布的参数进行调整.
torch.distribution
模块实现了大多数常见的分布. 对于复杂的分布, 你可以继承基类 Distribution
自己实现. 如果参数设置为 requires_grad=True
:
import torch
from torch import distributions, autograd
rate = torch.tensor(5.0, requires_grad=True)
exp_distri = distributions.Exponential(rate)
x = exp_distri.rsample()
print(x)
g = autograd.grad(x, rate)
print(g)
u = exp_distri.cdf(x)
print(torch.log(1 - u) / rate ** 2)
output:
样本: tensor(0.5938, grad_fn=<DivBackward0>)
rate 的梯度: (tensor(-0.1188),)
根据 Inverse Transform 计算的 rate 的梯度: tensor(-0.1188, grad_fn=<DivBackward0>)
可见, PyTorch 中指数分布的采样应该是根据 Inverse Transform 进行采样的. 但无论怎样, 都是一个模式:
所以, 不必纠结梯度是怎样的, 它的计算要根据简单分布的 Sampling 和 Transform 参数两部分得到.