本系列笔记为莫烦PyTorch视频教程笔记 github源码
概要
在训练时 loss 已经很小,但是把训练的 NN 放到测试集中跑,loss 突然飙升,这很可能出现了过拟合(overfitting)
减低过拟合,一般可以通过:加大训练集、loss function 加入正则化项、Dropout 等途径,这里演示 Dropout
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
torch.manual_seed(1)
%matplotlib inline
准备数据
出现过拟合一般是由于训练数据过少且网络结构较复杂,为了凸显过拟合问题,这里只用 10 个数据集来进行训练
DATA_SIZE = 10
# training set
x = torch.unsqueeze(torch.linspace(-1, 1, DATA_SIZE), dim=1) # sieze (20,1)
y = x + 0.3*torch.normal(torch.zeros(DATA_SIZE, 1), torch.ones(DATA_SIZE, 1))
x, y = Variable(x), Variable(y)
# test set
test_x = torch.unsqueeze(torch.linspace(-1, 1, DATA_SIZE), dim=