深度学习入门的第一个一般都是MNIST,很多框架都是这个,就像编程语言第一个程序是hello world!,单片机第一个程序是点亮一个LED一样。
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible
EPOCH = 8
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
# 下载数据集
if not (os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
DOWNLOAD_MNIST = True
train_data = torchvision.datasets.MNIST(
root='./mnist/',
train=True,
# 将一张图片或numpy数组转成(C × H ×W)torch.FloatTensor并归一化[0.0,0.1]
transform=torchvision.transforms.ToTensor(),
download=DOWNLOAD_MNIST,
)
print(train_data.train_data.size()) # (60000, 28,28)
print(train_data.train_labels.size()) # (60000)
plt.imshow(train_data.train_data[1].numpy(), cmap='gray') # train_data[0]
plt.title('%i' % train_da