MNIST手写数字识别教程
1、什么是MNIST?
MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。
MNIST中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。MNIST使用一个长度为10的one-hot向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。
在本文中,我们将在PyTorch中构建一个简单的卷积神经网络,并使用MNIST数据集训练它识别手写数字。
2、使用pytorch实现手写数字识别
2.1任务目的
如本文标题所示,MNIST手写数字识别的主要目为:训练出一个模型,让这个模型能够对手写数字图片进行分类。
2.2开发环境
为了实现本文的目标,你需要安装如下Python库
import torchvision.datasets
import torch
from torch.utils import data
from torchvision import transforms
from torch import nn
关于torch的安装教程在网上有很多,在这里就不过多赘述了,你也可以去官网上直接下载pytorch,其他的库都可以用pip直接安装。
2.3 构建数据集
torchvision中的torchvision.datasets库中提供了MNIST数据集的下载地址,因此我们可以通过框架中的内置函数将MNIST数据集下载并读取到内存中。
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST(
root='../data',train=True,transform=trans,download=True
)
mnist_test = torchvision.datasets.MNIST(
root='../data',train=False,transform=trans,download=True
)
Pytorch中提供了一种叫做DataLoader的方法来让我们进行训练,该方法自动将数据集打包成为迭代器,能够让我们很方便地进行后续的训练处理
batch_size = 64
train_dataloader = data.DataLoader(mnist_train,batch_size=batch_size)
test_dataloader = data.DataLoader(mnist_test,batch_size=batch_size)
for X,y in test_dataloader:
print("Shape of X [N,C,H,W]:",X.shape)
print("Shape of y:",y.shape,y.dtype)
break
至此,数据集已经准备完毕。
3 训练部分
3.1 构建模型
在这里使用的是一个简单的卷积神经网络,其结构如下
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork,self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28,512),
nn.ReLU(),
nn.Linear(512,512),
nn.ReLU(),
nn.Linear(512,10),
nn.ReLU()
)
def forward(self, X):
X = self.flatten(X)
logits = self.linear_relu_stack(X)
return logits
其中nn.Sequential函数能够自动将层数合并为一个模型,对于新手而言这种方式能够减少非常多的计算过程
随后,我们需要构建一个模型实例
model = NeuralNetwork().to(device=0)
print(model)
to() 方法用于将张量放入到指定的设备(如CPU或GPU中),记住的是:不同设备的张量是无法进行运算的
如果一切正常,那么输出结果如下
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
(5): ReLU()
)
)
3.2 构建迭代器与损失函数
对于简单的多分类任务,我们可以使用交叉熵损失来作为损失函数; 而对于迭代器而言,我们可以使用SGD迭代器
loss_fn = nn.CrossEntropyLoss()#交叉熵损失
optimizer = torch.optim.SGD(model.parameters(),lr = 1e-3)
模型在构建迭代器的时候需要将所有参数传入到迭代器中,可以通过**model.parameters()**方法来得到模型的所有参数。
3.3 构建训练循环
3.3.1 训练部分代码
对于训练部分,我们可以构造的模块为
def train(dataloader,model,loss_fn,optimizer):
size = len(dataloader.dataset)
for batch,(X,y) in enumerate(dataloader):
X,y = X.to(0),y.to(0)
pred = model(X)
loss = loss_fn(pred,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss,current = loss.item(), batch * len(X)
print(f"loss:{loss:>7f}[{current:>5d}/{size:>5d}]")
3.3.2 测试部分代码
对于测试部分,我们可以构造的模块为
def test(dataloader,model):
size = len(dataloader.dataset)
model.eval()
test_loss,correct = 0,0
with torch.no_grad():
for X,y in dataloader:
X,y = X.to(0),y.to(0)
pred = model(X)
test_loss += loss_fn(pred,y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= size
correct /=size
print(f"Test Error:\n Accuracy:{100*correct:>0.1f}%,Avg loss:{test_loss:>8f} \n")
3.3.3 训练循环代码
将上述两个循环进行结合,就是最终的训练循环代码了
epochs = 50
for t in range(epochs):
print(f"Epoch {t+1}\n-------------")
train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model)
print("Done!")
4.训练结果
假如一切正常,能看到以下的训练过程
Epoch 1
-------------
loss:2.301123[ 0/60000]
loss:2.298039[ 6400/60000]
loss:2.294332[12800/60000]
loss:2.282639[19200/60000]
loss:2.296278[25600/60000]
loss:2.300666[32000/60000]
loss:2.283164[38400/60000]
loss:2.287391[44800/60000]
loss:2.277942[51200/60000]
loss:2.268170[57600/60000]
Test Error:
Accuracy:28.8%,Avg loss:0.035733
最后得到最终的训练结果
Epoch 50
-------------
loss:0.911302[ 0/60000]
loss:1.000196[ 6400/60000]
loss:1.069001[12800/60000]
loss:0.644273[19200/60000]
loss:0.889767[25600/60000]
loss:1.041025[32000/60000]
loss:0.955698[38400/60000]
loss:1.179849[44800/60000]
loss:0.848570[51200/60000]
loss:1.003266[57600/60000]
Test Error:
Accuracy:65.5%,Avg loss:0.014824
Done!
Process finished with exit code 0
可以看到训练精度有待提高,后续可以换用不同的网络,或者是换一下损失函数或者优化器来得到精度的提升。