Truncated normal distribution()

Truncated normal distribution - Wikipedia

Normal Distribution 称为正态分布,也称为高斯分布,Truncated Normal Distribution一般翻译为截断正态分布,也有称为截尾正态分布。

截断正态分布是截断分布(Truncated Distribution)的一种,那么截断分布是什么?截断分布是指,限制变量xx 取值范围(scope)的一种分布。例如,限制x取值在0到50之间,即{0<x<50}。因此,根据限制条件的不同,截断分布可以分为:

  • 2.1 限制取值上限,例如,负无穷<x<50
  • 2.2 限制取值下限,例如,0<x<正无穷
  • 2.3 上限下限取值都限制,例如,0<x<50

正态分布则可视为不进行任何截断的截断正态分布,也即自变量的取值为负无穷到正无穷;

1. 概率密度函数

假设 X 原来服从正太分布,那么限制 x 的取值在(a,b)范围内之后,X 的概率密度函数,可以用下面公式计算:
f ( x ; μ , σ , a , b ) = 1 σ ϕ ( x − μ σ ) Φ ( b − μ σ ) − Φ ( a − μ σ ) f(x ; \mu, \sigma, a, b)=\frac{\frac{1}{\sigma} \phi\left(\frac{x-\mu}{\sigma}\right)}{\Phi\left(\frac{b-\mu}{\sigma}\right)-\Phi\left(\frac{a-\mu}{\sigma}\right)} f(x;μ,σ,a,b)=Φ(σbμ)Φ(σaμ)σ1ϕ(σxμ)
简写为:
⇒ f ( x ) F ( b ) − F ( a ) ⋅ I ( a < x < b ) \Rightarrow \frac{f(x)}{F(b)-F(a)} \cdot I(a<x<b) F(b)F(a)f(x)I(a<x<b)

  • 其中 ϕ(⋅):均值为 0,方差为 1 的标准正态分布;
    ϕ ( ξ ) = 1 2 π exp ⁡ ( − 1 2 ξ 2 ) \phi(\xi)=\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{1}{2} \xi^{2}\right) ϕ(ξ)=2π 1exp(21ξ2)
  • Φ(⋅) 为标准正态分布CDF;
    b → ∞ , ⇒ Φ ( b − μ σ ) = 1 a → − ∞ ⇒ Φ ( a − μ σ ) = 0 \begin{array}{l} b \rightarrow \infty, \Rightarrow \Phi\left(\frac{b-\mu}{\sigma}\right)=1 \\ a \rightarrow-\infty \Rightarrow \Phi\left(\frac{a-\mu}{\sigma}\right)=0 \end{array} b,Φ(σbμ)=1aΦ(σaμ)=0

参考:https://blog.csdn.net/lanchunhui/article/details/61623189

2. 截断正态分布的pytorch实现

def truncated_normal_(self,tensor,mean=0,std=0.09):
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size+(4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        ind = valid.max(-1, keepdim=True)[1]
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
        return tensor

参考:https://zhuanlan.zhihu.com/p/83609874


截断到两个标准差的实现(One simple approximated implementation for truncated_normal to 2*std.):

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

测试的话,可以使用

import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor



fig, ax = plt.subplots(1, 1)


def test_truncnorm():
    a, b = -2, 2
    size = 1000000
    r = truncnorm.rvs(a, b, size=size)
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50)

    tensor = torch.zeros(size)
    utils.truncated_normal_(tensor)
    r = tensor.numpy()

    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50)
    ax.legend(loc='best', frameon=False)
    plt.show()


if __name__ == '__main__':
    test_truncnorm()


Here is a simpler way to sample values from truncated normal distribution.

from scipy.stats import truncnorm
import torch


def truncated_normal(size, threshold=1):
    values = truncnorm.rvs(-threshold, threshold, size=size)
    return values

# usage example
x= truncnorm([10, 20], threshold=1)   # sample 10x20 sized tensor
x = torch.from_numpy(x).cuda()

推荐参考:https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/20

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值