权重衰退(weight_decay)是通过限制参数值的选择范围来控制模型容量
前要:如何控制一个模型的容量?
1、减少模型参数
2、限制参数值的选择范围
目录
1、使用均方范数作为硬性限制(相对麻烦,不常用)
- 通常不限制偏移b
- 小的
意味着更强的正则项
2、使用均方范数作为柔性限制(较为常用)
- 对每个
,都可以找到
使得之前的目标函数等价于下面(引入一个“惩罚项”):
- 超参数
控制了正则性的重要程度
= 0:无作用
=
,
3、引入惩罚项对最优解的影响
Ps:通过加入惩罚项来降低模型的复杂度,将原来优化的目标作出更改(相当于无条件极值变有条件极值),在一定程度上使用更少的参数,并使得模型的泛化能力提高(过多的参数可能会导致过拟合,过多参数会使运算量提高)
4、参数更新法则
- 计算梯度
- 时间 t 更新参数
- 通常
,在深度学习中通常叫做权重衰退
5、总结
- 权重衰退通过L2正则项使得模型参数不会过大,从而控制模型复杂度
- 正则项权重是控制模型复杂度的超参数
附:代码实现
# Import related library
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
# Generate data set
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
# Initialize the model parameters
def init_params():
w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
return [w, b]
# Define L2 norm penalties
def l2_penalty(w):
return torch.sum(w.pow(2)) / 2
# Define the training code implementation
def train(lambd):
w, b = init_params()
net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
num_epochs, lr = 100, 0.003
animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
xlim=[5, num_epochs], legend=['train', 'test'])
for epoch in range(num_epochs):
for X, y in train_iter:
with tf.GradientTape() as tape:
# 增加了L2范数惩罚项,
# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
l = loss(net(X), y) + lambd * l2_penalty(w)
grads = tape.gradient(l, [w, b])
d2l.sgd([w, b], grads, lr, batch_size)
if (epoch + 1) % 5 == 0:
animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
d2l.evaluate_loss(net, test_iter, loss)))
print('w的L2范数是:', tf.norm(w).numpy())
# Ignore the regularization direct training
train(lambd=0)
# Usage weight attenuation
train(lambd=3)
Q&A:
1、权重衰退的值一般设置多少较好?
按经验e^-3,e^-4,0.001等
2、为什么要把w往小里拉?如果最优解的w就是比较大的数,那么权重衰退是不是会有反作用?
因为数据集有噪音,算出来的w肯定会一定程度的偏大,需要lambd来拉回