- 和其他教程一样,莫烦大神也用MNIST作为CNN的入门
一. 调用库
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision # 数据库模块
import matplotlib.pyplot as plt
from torch.autograd import Variable # torch 中 Variable 模块
- 对一些库进行描述
- torch.utils.data
用于batchsize训练的库 - torchvision
用于加载数据集的库
1.首先下载数据集
# Mnist 手写数字
train_data = torchvision.datasets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # True为训练集,False为测试集
transform=torchvision.transforms.ToTensor(), #
# torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间 0-255->(0,1)
download=DOWNLOAD_MNIST, # True下载,False不下载
)
2.训练集设置
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) #shuffle True为乱序
- 这里解释下batch_size
,网络不可能同时训练所有的数据,把数据集打乱分批,一批一批的送进去训练。
- 加载测试集
1.这里对测试数据也要注意归一化‘/255’
# 为了节约时间, 我们测试时只测试前2000个
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]
4.搭CNN模型
-这里参考我上一篇博客,采用先建立流程图,在将流程图连接的形式。
# 继承nn.Module
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn