动手学深度学习PyTorch版 | (3)过拟合、欠拟合及其解决方案

一、过拟合、欠拟合概念

训练模型中经常出现的两类典型问题:

  • 欠拟合:模型无法得到较低的训练误差
  • 过拟合:模型的训练误差远小于它在测试数据集上的误差

在实践中,我们要尽可能同时应对欠拟合和过拟合。有很多因素可能导致这两种拟合问题,在这里我们重点讨论两个因素:模型复杂度和训练数据集大小。

- 模型复杂度

为了解释模型复杂度,我们以多项式函数拟合为例。给定一个由标量数据特征 x 和对应的标量标签 y 组成的训练数据集,多项式函数拟合的目标是找一个K阶多项式函数:
在这里插入图片描述
来近似 y。在上式中, w k w_k wk是模型的权重参数,b是偏差参数。与线性回归相同,多项式函数拟合也使用平方损失函数。特别地,一阶多项式函数拟合又叫线性函数拟合。

给定训练数据集,模型复杂度和误差之间的关系:

在这里插入图片描述
- 训练数据集大小

影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,特别是比模型参数数量(按元素计)更少时,过拟合更容易发生。此外,泛化误差不会随训练数据集里样本数量增加而增大。因此,在计算资源允许的范围之内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时,例如层数较多的深度学习模型。

下面以多项式函数拟合实验为例,进行训练并实例演示欠拟合和过拟合。

二、多项式函数拟合实验

引入实验需要的包

import torch
import numpy as np
import sys
sys.path.append("/home/kesci/input")
import d2lzh1981 as d2l
print(torch.__version__)

2.1 初始化模型参数

n_train, n_test,true_w,true_b = 100,100,[1.2,-3.4,5.6], 5
features = torch.randn((n_train + n_test, 1))
print(features.shape)
poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)  # cat()为连接操作,连接维度为列(1);pow()为张量的幂操作
print(poly_features.shape)
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1]
          + true_w[2] * poly_features[:, 2] + true_b)
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)  # 添加噪声

在这里插入图片描述

features[:2],poly_features[:2],labels[:2]

在这里插入图片描述

2.2 定义、训练和测试模型

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
             legend=None, figsize=(3.5, 2.5)):
    d2l.plt.xlabel(x_label)
    d2l.plt.ylabel(y_label)
    d2l.plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')
        d2l.plt.legend(legend)
num_epochs, loss = 100, torch.nn.MSELoss()

def fit_and_plot(train_features, test_features, train_labels, test_labels):
    # 初始化网络模型
    print('train_features.shape',train_features.shape)
    print(('train_features.shape[-1]',train_features.shape[-1]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值