pytorch经典数据集-手写数字识别
一、什么是MNIST?
MNIST是计算机视觉领域中最为基础的一个数据集,也是很多人第一个神经网络模型
MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集
MNIST中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。MNIST使用一个长度为10的one-hot向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。
二、详细代码介绍
MNIST手写数字识别的主要目为:训练出一个模型,让这个模型能够对手写数字图片进行分类。
首先先搞清楚步骤流程,然后才开始构建网络结构开始训练模型
导入要用到的库
utils是外部文件,自己定义的几个函数,详细代码已放文章末尾
#导入需要的各种库
import torch
#神经网络
from torch import nn
#function神经网络中常见的函数
from torch.nn import functional as F
#梯度下降优化包
from torch import optim
#图形视觉包
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot
加载数据集
#1 加载数据集
#load dataset
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)#shuffle打乱
#预览训练集数据
x, y=next(iter(train_loader))
print(x.shape,y.shape,x.min(),x.max())
#画图,图片识别,识别结果
plot_image(x,y,'image_sample')
用Net模型创建三层的网络结构+加一层relu激活函数层
#2 创建网络
#制作三层线性网络层 + relu函数 网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
#三层线性 xw +b
#第一层 28*28 =》打平成一个向量 输出是中间层,一般取2^n,逐步减小
#Linear(输入,输出)
self.fc1 = nn.Linear(28*28,256)
#第二层 上一层输出是这一层的输入
self.fc2 = nn.Linear(256,64)
#第三层 是最终的输出=== 分类数有关
self.fc3 = nn.Linear(64,10)
def forward(self,x):
# x[512,1,28,28] 输入层结构:512张灰度图片,28*28
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
#x = F.relu(self.fc3(x))
#一般来说,最后一层激活函数可加可不加
x = self.fc3(x)
return x
网络训练
#3 网络训练
#迭代的次数,对数据集迭代3次
for epoch in range(3):
#每次迭代,对数据集每512张做训练
for batch_idx, (x,y) in enumerate(train_loader):
# x[512,1,28,28] 28*28===1*784 打平矩阵,维度转换
x = x.view(x.size(0),28*28)#一维 1*784
# 放入网络训练
#out:[512,10]
out = net(x)
#label用onthot编码转化成向量
y_onehot = one_hot(y)
#计算loss 欧式距离
loss = F.mse_loss(out,y_onehot)
#梯度下降
#梯度清零
optimizer.zero_grad()
#计算梯度
loss.backward()
#更新梯度 w' = w - lr * grad
optimizer.step()
#此时退出循环,得到了最好的结果【w1,w2,w3,b1,b2,b3】
if batch_idx % 10 == 0:
#每10次打印loss
print(epoch,batch_idx,loss.item())
验证测试
#4 验证
total_correct = 0
for x,y in test_loader:
x = x.view(x.size(0),28*28)
out = net(x)#[512,x]
pred = out.argmax(dim =1)#dim维度
#pred =? 相等的数量有几张 eq()相等记为1,不相等记为0
correct =pred.eq(y).sum().float().item()
total_correct+=correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:",acc)
x,y =next(iter(test_loader))
out =net(x.view(x.size(0),28*28))
pred = out.argmax(dim =1)
plot_image(x,pred,'test')
全部代码
如果要编写成一个脚本的话,把下面函数部分复制到同一个py文件就行了,就不用多创建一个py文件,为了让代码更好维护与调试建议分开它
import torch
from matplotlib import pyplot as plt#绘图库
def plot_image(img, label, name):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label, depth=10):
out = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1, 1)
out.scatter_(dim=1, index=idx, value=1)
return out
def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)), data, color='blue')
plt.legend(['value'], loc='upper right')
plt.xlabel('step')
plt.ylabel('value')
plt.show()
mnist.py文件分享地址需要自取
链接: https://pan.baidu.com/s/1psjbAH5wxtaAyQpRXArr6g?pwd=y88a 提取码: y88a 复制这段内容后打开百度网盘手机App,操作更方便哦
如有错误之处请指正