截断正态分布(Truncated Normal Distribution)
截断正态分布是对标准正态分布(Normal Distribution)进行限制(截断),将其值域限制在某个范围内,生成的随机变量只会取位于指定范围之内的值。
1. 截断正态分布的定义
正态分布的概率密度函数为:
f ( x ) = 1 2 π σ exp ( − ( x − μ ) 2 2 σ 2 ) , − ∞ < x < ∞ f(x) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right), \quad -\infty < x < \infty f(x)=2πσ1exp(−2σ2(x−μ)2),−∞<x<∞
- μ \mu μ:均值(mean)。
- σ \sigma σ:标准差(standard deviation)。
截断正态分布 是将正态分布限制在某个范围 [ a , b ] [a, b] [a,b] 内,其概率密度函数变为:
f T ( x ) = { 1 σ 2 π exp ( − ( x − μ ) 2 2 σ 2 ) F ( b ) − F ( a ) , if a ≤ x ≤ b 0 , otherwise f_T(x) = \begin{cases} \frac{1}{\sigma \sqrt{2\pi}} \frac{\exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)}{F(b) - F(a)}, & \text{if } a \leq x \leq b \\ 0, & \text{otherwise} \end{cases} fT(x)=⎩ ⎨ ⎧σ2π1F(b)−F(a)exp(−2σ2(x−μ)2),0,if a≤x≤botherwise
其中:
- F ( x ) F(x) F(x):是标准正态分布的累积分布函数(CDF)。
- [ a , b ] [a, b] [a,b]:截断的范围。
通过这种方式,分布的值域被截断在指定范围内,超出范围的概率被重新分配到范围内。
2. 为什么要使用截断正态分布?
在机器学习和深度学习中,参数初始化对模型的训练至关重要。常用的初始化方法有 高斯分布(正态分布) 和 均匀分布。
但是,普通正态分布 的值域是无穷的,可能会产生极端值(太大或太小的值),这些极端值会导致训练不稳定或权重更新缓慢。
截断正态分布:
- 将权重初始化限制在合理范围内,避免过大的权重初始值。
- 让初始化的权重更加集中,确保模型更快收敛,且训练更稳定。
3. 截断正态分布的特点
-
值域有限:
- 正态分布被截断到区间 [ a , b ] [a, b] [a,b],不再产生超出该范围的值。
-
重新归一化:
- 在截断的范围内,对概率密度函数进行重新归一化,确保总概率为 1。
-
保留正态分布的形状:
- 在区间 [ a , b ] [a, b] [a,b] 内,概率密度函数仍然符合正态分布的形状。
-
避免极端值:
- 截断正态分布通过限制范围,有效防止了极端值的产生。
4. 截断正态分布与标准正态分布的区别
特性 | 标准正态分布 | 截断正态分布 |
---|---|---|
值域 | 无穷范围 ( − ∞ , + ∞ ) (-\infty, +\infty) (−∞,+∞) | 有限范围 [ a , b ] [a, b] [a,b] |
极端值概率 | 存在较大的极端值概率 | 极端值被截断,概率重新分布 |
归一化 | 无需重新归一化 | 需要在范围内重新归一化 |
用途 | 适合一般场景 | 适合需要限制初始值范围的场景 |
5. 在深度学习中的应用
截断正态分布常用于权重初始化,特别是在 Transformer、ViT 和其他深度学习模型中。它可以防止模型初始权重过大,导致训练不稳定。
PyTorch 中的截断正态分布
torch.nn.init
包中没有直接的截断正态分布实现,但 timm
库中提供了 trunc_normal_
方法来实现截断正态分布。
代码示例:
import torch
from timm.models.layers import trunc_normal_
# 创建一个空张量
tensor = torch.empty(3, 3)
# 使用截断正态分布初始化张量
trunc_normal_(tensor, mean=0.0, std=0.02, a=-0.04, b=0.04)
print(tensor)
参数说明:
tensor
:要初始化的张量。mean
:均值。std
:标准差。a, b
:截断范围,表示取值在 [ a , b ] [a, b] [a,b] 内。
输出:
- 张量中的值将被初始化为截断在 [ − 0.04 , 0.04 ] [-0.04, 0.04] [−0.04,0.04] 范围内的正态分布随机值。
6. 实现原理
截断正态分布的实现原理是:
- 从正态分布中采样:
- 生成正态分布的随机数。
- 判断是否在范围内:
- 如果随机数不在 [ a , b ] [a, b] [a,b] 范围内,则重新采样,直到满足条件。
- 重新归一化:
- 保证概率密度函数在区间 [ a , b ] [a, b] [a,b] 内的总和为 1。
7. 直观理解
- 如果你有一个标准正态分布,但你不想要极端值(例如,大于 2 σ 2\sigma 2σ 的值)。
- 你可以将分布截断在 [ − 2 σ , 2 σ ] [-2\sigma, 2\sigma] [−2σ,2σ] 范围内,这样就保证了所有的数值都在这个范围内。
图示如下:
正态分布 (未截断) 截断正态分布
| |
| ***** | *****
| **** **** | **** ****
| *** *** | *** ***
| ** ** | ** **
| ** ** | ** **
---|------------------------- ----|--------------------------