Pytorch 数值稳定性,模型初始化
这一节内容有很多数学推导,大家可以多看看李沐老师的视频和教材理解理解。我摊牌了,这一章我没完全听懂。就大概记了下能大概听懂的内容,准备以后的学习中进一步加深对本节课的理解。
1. 数值稳定性
1.1 神经网络的梯度
考虑如下有
d
d
d 层神经网络:
h
t
=
f
t
(
h
t
−
1
)
and
y
=
ℓ
∘
f
d
∘
…
∘
f
1
(
x
)
\mathbf{h}^{t}=f_{t}\left(\mathbf{h}^{t-1}\right) \quad \text { and } \quad y=\ell \circ f_{d} \circ \ldots \circ f_{1}(\mathbf{x})
ht=ft(ht−1) and y=ℓ∘fd∘…∘f1(x)
计算损失
ℓ
\ell
ℓ 关于
W
t
\mathbf{W}^{t}
Wt 的 的梯度:
∂
l
∂
W
t
=
∂
l
∂
h
d
∂
h
d
∂
h
d
−
1
…
∂
h
t
+
1
∂
h
t
∂
h
t
∂
W
t
\frac{\partial l}{\partial \mathbf{W}^{t}}=\frac{\partial l}{\partial \mathbf{h}^{d}} \frac{\partial \mathbf{h}^{d}}{\partial \mathbf{h}^{d-1}} \ldots \frac{\partial \mathbf{h}^{t+1}}{\partial \mathbf{h}^{t}} \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{W}^{t}}
∂Wt∂l=∂hd∂l∂hd−1∂hd…∂ht∂ht+1∂Wt∂ht
1.2 梯度爆炸和梯度消失
当每层梯度都是大于 1 的情况下,层数变多,最后得到的数值会越来越大。
当每层梯度都是小于 1 的情况下,层数变多,最后得到的数值会越来越小。
1.3 例子:MLP
加入如下 MLP (为了简便先不考虑偏置
b
b
b):
f
t
(
h
t
−
1
)
=
σ
(
W
t
h
t
−
1
)
∂
h
t
∂
h
t
−
1
=
diag
(
σ
′
(
W
t
h
t
−
1
)
)
(
W
t
)
T
∏
i
=
t
d
−
1
∂
h
i
+
1
∂
h
i
=
∏
i
=
t
d
−
1
diag
(
σ
′
(
W
i
h
i
−
1
)
)
(
W
i
)
T
f_{t}\left(\mathbf{h}^{t-1}\right)=\sigma\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right) \\ \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{h}^{t-1}}=\operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right)\right)\left(W^{t}\right)^{T} \\ \prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
ft(ht−1)=σ(Wtht−1)∂ht−1∂ht=diag(σ′(Wtht−1))(Wt)Ti=t∏d−1∂hi∂hi+1=i=t∏d−1diag(σ′(Wihi−1))(Wi)T
其中
σ
\sigma
σ 是激活函数,
σ
′
\sigma^{\prime}
σ′ 是
σ
\sigma
σ 的导函数。
1.3.1 梯度爆炸
使用
R
e
L
U
ReLU
ReLU 作为激活函数:
σ
(
x
)
=
max
(
0
,
x
)
and
σ
′
(
x
)
=
{
1
if
x
>
0
0
otherwise
\sigma(x)=\max (0, x) \quad \text { and } \quad \sigma^{\prime}(x)= \begin{cases}1 & \text { if } x>0 \\ 0 & \text { otherwise }\end{cases}
σ(x)=max(0,x) and σ′(x)={10 if x>0 otherwise
∏
i
=
t
d
−
1
∂
h
i
+
1
∂
h
i
=
∏
i
=
t
d
−
1
diag
(
σ
′
(
W
i
h
i
−
1
)
)
(
W
i
)
T
\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
∏i=td−1∂hi∂hi+1=∏i=td−1diag(σ′(Wihi−1))(Wi)T 的一些元素会来自于
∏
i
=
t
d
−
1
(
W
i
)
T
\prod_{i=t}^{d-1}\left(W^{i}\right)^{T}
∏i=td−1(Wi)T。如果
d
−
t
d-t
d−t 很大,得到的数值将会很大。
梯度爆炸的问题:
- 值超出值域
- 对学习率敏感
- 若学习率太大->大参数值->更大的梯度
- 若学习率太小->训练无进展
- 我们可能需要在驯良过程中不断调整学习率
1.3.2 梯度消失
使用
s
i
g
m
o
i
d
sigmoid
sigmoid 作为激活函数:
σ
(
x
)
=
1
1
+
e
−
x
σ
′
(
x
)
=
σ
(
x
)
(
1
−
σ
(
x
)
)
\sigma(x)=\frac{1}{1+e^{-x}} \quad \sigma^{\prime}(x)=\sigma(x)(1-\sigma(x))
σ(x)=1+e−x1σ′(x)=σ(x)(1−σ(x))
∏
i
=
t
d
−
1
∂
h
i
+
1
∂
h
i
=
∏
i
=
t
d
−
1
diag
(
σ
′
(
W
i
h
i
−
1
)
)
(
W
i
)
T
\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}
∏i=td−1∂hi∂hi+1=∏i=td−1diag(σ′(Wihi−1))(Wi)T 的元素值是
d
−
t
d-t
d−t 个小数值的乘积。
梯度消失的问题:
- 梯度值变成 0 0 0
- 训练没有进展
- 不管如何选择学习率
- 对于底部层尤为严重
- 仅仅顶部层训练的较好
- 无法让神经网络更深
2. 模型初始化
2.1 让训练更加稳定
- 目标:让梯度值在合理的范围内
- 例如 [ 1 e − 6 , 1 e 3 ] [1e-6, 1e3] [1e−6,1e3]
- 将乘法变加法
- ResNet,LSTM
- 归一化
- 梯度归一化,梯度裁剪
- 合理的权重初始和激活函数
2.2 让每层的方差是一个常数
- 将每层的输出和梯度都看作随机变量
- 让它们的均值和方差都保持一致
2.3 权重初始化
- 在合理值区间里随机初始参数
- 训练开始的时候更容易有数值不稳定
- 远离最优解的地方损失函数表面可能很复杂
- 最优解附近表面会比较平
- 使用 N ( 0 , 0.01 ) N(0, 0.01) N(0,0.01) 来初始可能对小网络没问题,但不能保证深度神经网络。
2.4 Xavier 初始化
Xavier 初始化从均值为零,方差 σ 2 = 2 n i n + n o u t \sigma^2 = \frac{2}{n_\mathrm{in} + n_\mathrm{out}} σ2=nin+nout2 的高斯分布中采样权重。 我们也可以利用 Xavier 的直觉来选择从均匀分布中抽取权重时的方差。 注意均匀分布 U ( − a , a ) U(-a, a) U(−a,a) 的方差为 a 2 3 \frac{a^2}{3} 3a2。 将 a 2 3 \frac{a^2}{3} 3a2 代入到 σ 2 \sigma^2 σ2 的条件中,将得到初始化值域:
U ( − 6 n i n + n o u t , 6 n i n + n o u t ) . U\left(-\sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}, \sqrt{\frac{6}{n_\mathrm{in} + n_\mathrm{out}}}\right). U(−nin+nout6,nin+nout6).
尽管在上述数学推理中,“不存在非线性”的假设在神经网络中很容易被违反, 但 Xavier 初始化方法在实践中被证明是有效的。
Xavier 初始化表明,对于每一层,输出的方差不受输入数量的影响,任何梯度的方差不受输出数量的影响。