文章目录
一、过拟合、欠拟合概念
训练模型中经常出现的两类典型问题:
- 欠拟合:模型无法得到较低的训练误差
- 过拟合:模型的训练误差远小于它在测试数据集上的误差
在实践中,我们要尽可能同时应对欠拟合和过拟合。有很多因素可能导致这两种拟合问题,在这里我们重点讨论两个因素:模型复杂度和训练数据集大小。
- 模型复杂度
为了解释模型复杂度,我们以多项式函数拟合为例。给定一个由标量数据特征 x 和对应的标量标签 y 组成的训练数据集,多项式函数拟合的目标是找一个K阶多项式函数:
来近似 y。在上式中, w k w_k wk是模型的权重参数,b是偏差参数。与线性回归相同,多项式函数拟合也使用平方损失函数。特别地,一阶多项式函数拟合又叫线性函数拟合。
给定训练数据集,模型复杂度和误差之间的关系:
- 训练数据集大小
影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,特别是比模型参数数量(按元素计)更少时,过拟合更容易发生。此外,泛化误差不会随训练数据集里样本数量增加而增大。因此,在计算资源允许的范围之内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时,例如层数较多的深度学习模型。
下面以多项式函数拟合实验为例,进行训练并实例演示欠拟合和过拟合。
二、多项式函数拟合实验
引入实验需要的包
import torch
import numpy as np
import sys
sys.path.append("/home/kesci/input")
import d2lzh1981 as d2l
print(torch.__version__)
2.1 初始化模型参数
n_train, n_test,true_w,true_b = 100,100,[1.2,-3.4,5.6], 5
features = torch.randn((n_train + n_test, 1))
print(features.shape)
poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1) # cat()为连接操作,连接维度为列(1);pow()为张量的幂操作
print(poly_features.shape)
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1]
+ true_w[2] * poly_features[:, 2] + true_b)
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float) # 添加噪声
features[:2],poly_features[:2],labels[:2]
2.2 定义、训练和测试模型
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
legend=None, figsize=(3.5, 2.5)):
d2l.plt.xlabel(x_label)
d2l.plt.ylabel(y_label)
d2l.plt.semilogy(x_vals, y_vals)
if x2_vals and y2_vals:
d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')
d2l.plt.legend(legend)
num_epochs, loss = 100, torch.nn.MSELoss()
def fit_and_plot(train_features, test_features, train_labels, test_labels):
# 初始化网络模型
print('train_features.shape',train_features.shape)
print(('train_features.shape[-1]',train_features.shape[-1]