mnist数据集作为一个经典的数据集,不少机器学习深度学习的初学者都对其下过手,它干净的数据集、训练难度也不是很高,作为初学练手,很适合了。今天我也来拜会一下这位经典大哥,看看到底感觉如何。
本次训练使用RNN的网络模型,其实其他模型也可以,但是这两天不是正好温习RNN嘛,就刚好学以致用嘛。
请套路先生上场吧:
1.导包定参
当你新建完一个python文件不管是py文件还是ipynb文件,不管会不会,先把常用的包导入进来再说,常定义几个超参数也请上来:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
#定义超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.001
TIME_STEP = 28
INPUT_SIZE = 28
2.数据加载
先把数据加载进来,虽然mnist数据集可以在线下载,但可能比较慢,我更喜欢先把数据下载到本地再加载进来,这一块的处理和cifar-10那篇就很像了。可以去回顾一下:
#数据加载
train_data = datasets.MNIST (root='./mnist',train=True,transform=transforms.ToTensor(),download=True )
test_data = datasets.MNIST(root='./mnist',train=False,
transform=transforms.ToTensor(),download=True )
print(train_data)
print(test_data)
我们把数据的一些特征打印出来看一眼:
Dataset MNIST
Number of datapoints: 60000
Root location: ./mnist
Split: Train
StandardTransform
Transform: ToTensor()
Dataset MNIST
Number of datapoints: 10000
Root location: ./mnist
Split: Test
StandardTransform
Transform: ToTensor()
可以看到训练集60000个样本,测试集10000个,现在已经都被我们转为tensor格式了。
下面该干嘛了,数据预处理哈,归一化、分批啥的都可以搞起了。
3.数据预处理
图片数据可视化
数据是图像的话我比较喜欢先搞一张出来看看:
plt.imshow(train_data.train_data[0].numpy(),cmap='gray'