梯度裁剪
梯度用数学式子表达即为偏导数,梯度爆炸就等价于偏导数很大。
梯度裁剪应用的前提
网络在训练过程中出现了梯度爆炸(loss激增or直接NaN)的现象。梯度爆炸问题一般会随着网络层数的增加而变得越来越明显。
根据weight更新的计算公式: w i = w i − α ∂ N ( Θ ) ∂ w i w_i = w_i - \alpha \frac{\partial N(\Theta)}{\partial w_i} wi=wi−α∂wi∂N(Θ), 其中 ∂ N ( Θ ) ∂ w i \frac{\partial N(\Theta)}{\partial w_i} ∂wi∂N(Θ)为复合函数求导。若每一层的偏导数都大于1,且层数又多,则容易发生梯度爆炸现象。
解决梯度爆炸问题的办法
- 将学习率 α \alpha α设置得小一点。但是如果使用重启余弦退火算法,就不可避免地会在重启时学习率跳到很大,学习率调到太小也不利于继续寻找最优解。而
- 使用梯度裁剪,即控制 ∂ N ( Θ ) ∂ w i \frac{\partial N(\Theta)}{\partial w_i} ∂wi∂N(Θ)不能超过某一阈值。
PyTorch使用梯度裁剪
梯度裁剪的使用位置在loss.backward()
得到梯度值之后,在使用optimizer.step()
进行权重值更新之前。
固定阈值裁剪(对象:for gradient of each parameter)
约束每个权重参数 w i w_i wi的梯度值 ∂ N ( Θ ) ∂ w i ∈ [ − x , x ] \frac{\partial N(\Theta)}{\partial w_i}\in [-x, x] ∂wi∂N(Θ)∈[−x,x].
例子:
import torch.nn as nn
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(model.parameters(), clip_value=x)
optimizer.step()
优点:简单粗暴;
缺点:因为是element-wise的操作,很难找到合适的阈值。
根据参数的范数来衡量(对象:for all gradients)
对所有权重参数 { w i } i = 1 K \{w_i\}_{i=1}^{K} {wi}i=1K的总体范数进行约束。所有权重参数的总体范数为 ∥ g ∥ n = [ ( ∂ N ( Θ ) ∂ w 1 ) n + ⋯ + ( ∂ N ( Θ ) ∂ w K ) n ] 1 n \Vert \mathbf{g}\Vert_n = {[(\frac{\partial N(\Theta)}{\partial w_1})^n + \cdots + (\frac{\partial N(\Theta)}{\partial w_K})^n]}^{\frac{1}{n}} ∥g∥n=[(∂w1∂N(Θ))n+⋯+(∂wK∂N(Θ))n]n1,设置阈值 c c c。
当
∥
g
∥
n
≤
c
\Vert \mathbf{g}\Vert_n\leq c
∥g∥n≤c时,不做clip
的操作。
当 ∥ g ∥ n > c \Vert \mathbf{g}\Vert_n > c ∥g∥n>c,执行操作 g : = c ∥ g ∥ n ⋅ g \mathbf{g} := \frac{c}{\Vert \mathbf{g}\Vert_n} \cdot \mathbf{g} g:=∥g∥nc⋅g,最终 ∥ g ∥ n = c \Vert \mathbf{g}\Vert_n = c ∥g∥n=c。
所以,通过对范数进行约束,以达到 ∥ g ∥ n ≤ c \Vert \mathbf{g}\Vert_n\leq c ∥g∥n≤c的目的,这种对于梯度值整体做缩放的操作,比较容易调参。
例子:
import torch.nn as nn
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=c, norm_type=2)
optimizer.step()
根据经验, c c c值的选取应由大到小,依次为10, 5, 1, 0.1。如果都不行的话,则需观察梯度更新时的值来进行具体设置。
关于实时梯度值的观测
此处举例观察各层梯度值的二范数(应该写在loss.backward()
之后)中的最大值
max_grad_norm = 0
for param in model.parameters():
if param.grad is not None:
max_grad_norm = max(param.grad.data.norm(2), max_grad_norm)
print("the max value of gradient's L2 norm is {}".format(max_grad_norm))