Truncated Normal Distribution
The Normal Distribution is also known as the Gaussian distribution; the Truncated Normal Distribution (截断正态分布, sometimes rendered as 截尾正态分布) is the normal distribution restricted to an interval.
The truncated normal distribution is one kind of truncated distribution. So what is a truncated distribution? A truncated distribution is obtained by restricting the range of the variable x, for example restricting x to lie between 0 and 50, i.e. {0 < x < 50}. Depending on the restriction, truncated distributions fall into three cases:
- upper bound only, e.g. −∞ < x < 50;
- lower bound only, e.g. 0 < x < +∞;
- both an upper and a lower bound, e.g. 0 < x < 50.
The ordinary normal distribution can be viewed as a truncated normal distribution with no truncation at all, i.e. the variable ranges from −∞ to +∞.
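As a quick illustration of the three cases, here is a minimal sketch using scipy.stats.truncnorm (the values mu = 25 and sigma = 10 are arbitrary choices for illustration, not from the original post; note that truncnorm takes its bounds in standard-deviation units relative to loc and scale):

import numpy as np
from scipy.stats import truncnorm

mu, sigma = 25.0, 10.0   # illustrative parameters

# upper bound only: -inf < x < 50
upper_only = truncnorm(a=-np.inf, b=(50 - mu) / sigma, loc=mu, scale=sigma)
# lower bound only: 0 < x < +inf
lower_only = truncnorm(a=(0 - mu) / sigma, b=np.inf, loc=mu, scale=sigma)
# both bounds: 0 < x < 50
both = truncnorm(a=(0 - mu) / sigma, b=(50 - mu) / sigma, loc=mu, scale=sigma)

print(upper_only.rvs(5), lower_only.rvs(5), both.rvs(5))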
1. Probability Density Function
Suppose X originally follows a normal distribution. After restricting x to the interval (a, b), the probability density function of X is given by:
f(x ; \mu, \sigma, a, b)=\frac{\frac{1}{\sigma} \phi\left(\frac{x-\mu}{\sigma}\right)}{\Phi\left(\frac{b-\mu}{\sigma}\right)-\Phi\left(\frac{a-\mu}{\sigma}\right)}
This can be written more compactly (with f and F the pdf and CDF of the untruncated normal distribution) as:
\Rightarrow \frac{f(x)}{F(b)-F(a)} \cdot I(a<x<b)
- where ϕ(·) is the probability density of the standard normal distribution (mean 0, variance 1):
\phi(\xi)=\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{1}{2} \xi^{2}\right)
- Φ(·) is the CDF of the standard normal distribution; as the bounds go to infinity, the truncation has no effect:
b \rightarrow \infty \Rightarrow \Phi\left(\frac{b-\mu}{\sigma}\right)=1, \qquad a \rightarrow -\infty \Rightarrow \Phi\left(\frac{a-\mu}{\sigma}\right)=0
Reference: https://blog.csdn.net/lanchunhui/article/details/61623189
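To make the formula concrete, here is a small sketch (the values μ=0, σ=1, a=−1, b=2 are illustrative, not from the original post) that evaluates the truncated density directly from the formula above and checks it against scipy.stats.truncnorm.pdf:

import numpy as np
from scipy.stats import norm, truncnorm

mu, sigma, a, b = 0.0, 1.0, -1.0, 2.0   # illustrative parameters

def truncated_pdf(x, mu, sigma, a, b):
    # f(x; mu, sigma, a, b) = (1/sigma) * phi((x-mu)/sigma) / (Phi((b-mu)/sigma) - Phi((a-mu)/sigma))
    z = (x - mu) / sigma
    denom = norm.cdf((b - mu) / sigma) - norm.cdf((a - mu) / sigma)
    pdf = norm.pdf(z) / (sigma * denom)
    return np.where((x > a) & (x < b), pdf, 0.0)   # indicator I(a < x < b)

x = np.linspace(-2, 3, 7)
ours = truncated_pdf(x, mu, sigma, a, b)
ref = truncnorm.pdf(x, (a - mu) / sigma, (b - mu) / sigma, loc=mu, scale=sigma)
print(np.allclose(ours, ref))   # expected: True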
2. PyTorch Implementation of the Truncated Normal Distribution
import torch

def truncated_normal_(tensor, mean=0, std=0.09):
    # In-place init: fill `tensor` with values drawn from N(mean, std^2),
    # with (almost all) values lying within 2 std of the mean.
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size + (4,)).normal_()       # 4 candidate samples per element
        valid = (tmp < 2) & (tmp > -2)                       # keep candidates within [-2, 2]
        ind = valid.max(-1, keepdim=True)[1]                 # index of the first valid candidate (falls back to 0 if none, hence "approximate")
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))   # pick one candidate per element
        tensor.data.mul_(std).add_(mean)                     # scale and shift to N(mean, std^2)
    return tensor
Reference: https://zhuanlan.zhihu.com/p/83609874
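For example (a usage sketch, assuming the truncated_normal_ function above is in scope), it can be used to initialize the weights of a linear layer:

import torch
import torch.nn as nn

layer = nn.Linear(128, 64)
truncated_normal_(layer.weight, mean=0.0, std=0.09)   # weights now lie (approximately) within mean ± 2*std
print(layer.weight.min().item(), layer.weight.max().item())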
An implementation that truncates at two standard deviations (a simple approximate implementation of truncated_normal to 2*std):
def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
To test it, you can compare it against scipy.stats.truncnorm:
import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor

fig, ax = plt.subplots(1, 1)

def test_truncnorm():
    a, b = -2, 2
    size = 1000000
    # reference samples from scipy's exact truncated normal
    r = truncnorm.rvs(a, b, size=size)
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='scipy truncnorm')
    # samples from the approximate PyTorch implementation above
    tensor = torch.zeros(size)
    truncated_normal_(tensor)
    r = tensor.numpy()
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='pytorch approx')
    ax.legend(loc='best', frameon=False)
    plt.show()

if __name__ == '__main__':
    test_truncnorm()
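Note that because the sampler truncates at ±2 before scaling, the empirical standard deviation of the result is noticeably smaller than the std argument. A quick check (a small sketch, not from the original post, assuming truncated_normal_ above is in scope):

import torch
from scipy.stats import truncnorm

t = torch.zeros(1000000)
truncated_normal_(t, mean=0, std=1)
print(t.std().item())        # roughly 0.88, not 1.0
print(truncnorm.std(-2, 2))  # theoretical std of a standard normal truncated to [-2, 2], also ~0.88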
Here is a simpler way to sample values from a truncated normal distribution.
from scipy.stats import truncnorm
import torch

def truncated_normal(size, threshold=1):
    # sample from a standard normal truncated to [-threshold, threshold]
    values = truncnorm.rvs(-threshold, threshold, size=size)
    return values

# usage example
x = truncated_normal([10, 20], threshold=1)  # sample a 10x20 array
x = torch.from_numpy(x).cuda()
Recommended reference: https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/20
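Newer PyTorch versions also ship a built-in truncated-normal initializer, torch.nn.init.trunc_normal_ (check that your installed version provides it); a minimal sketch of how it could replace the hand-rolled versions above:

import torch
import torch.nn as nn

w = torch.empty(64, 128)
# values are effectively drawn from N(mean, std^2) and kept within the absolute bounds [a, b];
# here a and b are set to mean ± 2*std to match the approximations above
nn.init.trunc_normal_(w, mean=0.0, std=0.09, a=-0.18, b=0.18)
print(w.min().item(), w.max().item())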