正则化之weight decay
1、正则化与偏差-方差分解
机器学习中的误差可以看作噪声+偏差+方差:
- 噪声:在当前任务上任何学习算法所能达到的期望泛化误差的下界,无法通过优化模型来减小
- 偏差:指一个模型在不同训练集上的平均性能和最优模型的差异,度量了学习算法的期望预测与真实结果的偏离程度,即刻画了学习算法本身的拟合能力
- 方差:指一个模型在不同训练集上的差异,度量了同样大小的训练集的变动所导致的学习性能的变化,即刻画了数据 扰动所造成的影响,可以用来衡量一个模型是否容易过拟合
正则化就是用来降低方差,从而减小过拟合的方法,如果记录损失函数为:
L
o
s
s
=
f
(
y
^
,
y
)
Loss = f (\hat{y}, y)
Loss=f(y^,y),则对于样本而言,代价函数为
c
o
s
t
=
1
N
∑
i
N
f
(
y
^
i
,
y
i
)
cost=\frac{1}{N} \sum_{i}^{N} f\left(\hat{y}_{i}, y_{i}\right)
cost=N1∑iNf(y^i,yi),目标函数为:
O
b
j
=
C
o
s
t
+
R
e
g
u
l
a
r
i
z
a
t
i
o
n
Obj = Cost + Regularization
Obj=Cost+Regularization
这里的
R
e
g
u
l
a
r
i
z
a
t
i
o
n
Regularization
Regularization就是正则项,用于减小方差的策略,常见的有两种正则化方式:
- L 1 L1 L1正则: Σ i N ∣ w i ∣ \Sigma_{i}^{N}\left|w_{i}\right| ΣiN∣wi∣
- L 2 L2 L2正则: ∑ i N w i 2 \sum_{i}^{N} w_{i}^{2} ∑iNwi2
2、Pytorch
中的L2正则项——weight decay
L
2
L2
L2正则项本质上相当于是权值衰减,这是因为目标函数为
O
b
j
=
C
o
s
t
+
R
e
g
u
l
a
r
i
z
a
t
i
o
n
=
L
o
s
s
+
λ
2
∑
i
N
w
i
2
Obj = Cost + Regularization = Loss + \frac{\lambda}{2} \sum_{i}^{N} w_{i}^{2}
Obj=Cost+Regularization=Loss+2λ∑iNwi2,在梯度下降公式中:
w
i
+
1
=
w
i
−
∂
o
b
j
∂
w
i
=
w
i
−
(
∂
L
o
s
s
∂
w
i
+
λ
w
i
)
=
(
1
−
λ
)
w
i
−
∂
L
o
s
s
∂
w
i
\begin{aligned} w_{i+1}&=w_{i}-\frac{\partial o b j}{\partial w_{i}}\\ &=w_{i}-(\frac{\partial L o s s}{\partial w_{i}}+\lambda w_i)\\ & = (1-\lambda)w_i - \frac{\partial L o s s}{\partial w_{i}} \end{aligned}
wi+1=wi−∂wi∂obj=wi−(∂wi∂Loss+λwi)=(1−λ)wi−∂wi∂Loss
由于这里的正则化系数
λ
\lambda
λ是一个介于
0
0
0到
1
1
1之间数,因此可以看出,与不加正则项的梯度下降公式——
w
i
+
1
=
w
i
−
∂
L
o
s
s
∂
w
i
w_{i+1}= w_i - \frac{\partial L o s s}{\partial w_{i}}
wi+1=wi−∂wi∂Loss相比,相当于是做了一个权值的下降。
Pytorch
中的 weight decay 是在优化器中实现的,在优化器中加入参数weight_decay=
即可,例如下面的两个随机梯度优化器,一个是没有加入正则项,一个加入了正则项,区别仅仅在于是否设置了参数weight_decay
的值:
optim_normal = torch.optim.SGD(net_normal.parameters(), lr=lr_init, momentum=0.9)
optim_wdecay = torch.optim.SGD(net_weight_decay.parameters(), lr=lr_init, momentum=0.9, weight_decay=1e-2)
最终,我们可以得到通过这两个优化器训练得到的模型:
红色曲线为没有 weight decay 的结果,蓝色虚线为加入 weight decay 的训练结果,可以看到加入后能够非常有效的缓解过拟合现象。
下图为第二层的权值柱状图,左边为加入正则项,右边为没有加正则项,可以看出,左边的权值确实有一个递减的趋势,而右边几乎是保持不变的状态。