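1. MNIST handwritten digit recognition
The MNIST dataset consists of two parts: a training set with 60,000 images and a test set with 10,000 images. The dataset downloaded from the official website consists of 4 files:
| File name | Purpose |
|---|---|
| train-images-idx3-ubyte.gz | Training set images |
| train-labels-idx1-ubyte.gz | Training set labels |
| t10k-images-idx3-ubyte.gz | Test set images |
| t10k-labels-idx1-ubyte.gz | Test set labels |
References:
Introduction to the MNIST dataset (1)
Introduction to the MNIST dataset (2)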
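The four files use the simple IDX binary format described on the MNIST site. Each file starts with a 4-byte magic number: two zero bytes, a type code (0x08 means unsigned byte), and the number of dimensions. Next comes one big-endian 32-bit size per dimension, then the raw pixel or label bytes.

`torchvision.datasets.MNIST`, used in the code below, parses these files for you. This minimal sketch only shows what that parsing amounts to. The names `parse_idx` and `load_idx_gz` are hypothetical helpers, and only the unsigned-byte type used by MNIST is handled.

```python
import gzip
import struct

import numpy as np


def parse_idx(data: bytes) -> np.ndarray:
    """Parse an IDX byte string (the format of the MNIST files) into a NumPy array."""
    # Magic number: 2 zero bytes, 1 type-code byte, 1 byte with the number of dimensions
    zero, type_code, ndim = struct.unpack('>HBB', data[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError('this sketch only handles unsigned-byte IDX files')
    # One big-endian uint32 per dimension, e.g. (60000, 28, 28) for the training images
    dims = struct.unpack('>' + 'I' * ndim, data[4:4 + 4 * ndim])
    # The remaining bytes are the values in row-major order
    arr = np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim)
    return arr.reshape(dims)


def load_idx_gz(path: str) -> np.ndarray:
    """Read a gzipped IDX file such as train-images-idx3-ubyte.gz."""
    with gzip.open(path, 'rb') as f:
        return parse_idx(f.read())
```

Assuming the files sit in the working directory, `load_idx_gz('train-images-idx3-ubyte.gz')` should return a uint8 array of shape (60000, 28, 28). The matching label file should give an array of shape (60000,).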
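2. Model training with an augmented dataset
There are already plenty of blog posts online about handwritten digit recognition with PyTorch, so this post does not go over the basics in detail. Instead, it augments the MNIST dataset and trains the model on 3 different dataset compositions, each containing 120,000 images:
- Original dataset (60,000 images) + pixel-inverted images (60,000 images)
- Original dataset (60,000 images) + images rotated by 90°, 180° and 270° in equal, class-balanced amounts (60,000 images). Here "equal and class-balanced" means that 20,000 images were rotated by each angle, and each group of 20,000 contains 2,000 images of each of the ten digits 0-9.
- Original dataset (60,000 images) + pixel-inverted images (30,000 images) + equal, class-balanced rotated images (30,000 images)
I recommend trying to build these splits yourself; alternatively, you can use the pre-split data: click -> pre-split data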
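If you want to build the three compositions yourself, the following NumPy sketch is one possible way to do it. It assumes the training images are already loaded as a uint8 array of shape (N, 28, 28), with a matching label array.

Several details are assumptions and may differ from how the author's pre-split data was produced:
- the helper names,
- the fixed random seed,
- the counter-clockwise rotation direction,
- sampling each angle independently.

Independent sampling per angle matters because some digits (e.g. 5) have fewer than 6,000 training images, so 3 × 2,000 distinct images per digit are not always available.

```python
import numpy as np


def invert(images: np.ndarray) -> np.ndarray:
    """Pixel inversion for uint8 images (0 <-> 255): white-on-black becomes black-on-white."""
    return 255 - images


def balanced_rotations(images, labels, per_class, angles=(90, 180, 270), seed=0):
    """Rotate `per_class` randomly chosen images of every digit 0-9 by each angle.

    Images are sampled without replacement within one angle but independently
    across angles (an assumption), so one source image may appear under several angles.
    """
    rng = np.random.default_rng(seed)
    out_imgs, out_labels = [], []
    for angle in angles:
        k = angle // 90  # np.rot90 turns the image counter-clockwise k times
        for digit in range(10):
            candidates = np.flatnonzero(labels == digit)
            chosen = rng.choice(candidates, size=per_class, replace=False)
            out_imgs.append(np.rot90(images[chosen], k=k, axes=(1, 2)))
            out_labels.append(labels[chosen])
    return np.concatenate(out_imgs), np.concatenate(out_labels)


def build_composition(images, labels, kind, seed=0):
    """Original set plus the same number of augmented images (60,000 + 60,000 for MNIST).

    kind: 'invert' -> every image inverted
          'rotate' -> n // 30 images per digit per angle (2,000 for n = 60,000)
          'mixed'  -> a random half inverted + n // 60 rotated images per digit per angle
    """
    n = len(images)
    if kind == 'invert':
        extra_x, extra_y = invert(images), labels
    elif kind == 'rotate':
        extra_x, extra_y = balanced_rotations(images, labels, n // 30, seed=seed)
    elif kind == 'mixed':
        half = np.random.default_rng(seed).permutation(n)[: n // 2]
        rot_x, rot_y = balanced_rotations(images, labels, n // 60, seed=seed)
        extra_x = np.concatenate([invert(images[half]), rot_x])
        extra_y = np.concatenate([labels[half], rot_y])
    else:
        raise ValueError('unknown kind: {}'.format(kind))
    return np.concatenate([images, extra_x]), np.concatenate([labels, extra_y])
```

With the full training set (n = 60,000) the counts work out as follows:
- `'rotate'` adds 3 × 10 × 2,000 = 60,000 images.
- `'mixed'` adds 30,000 inverted images plus 3 × 10 × 1,000 = 30,000 rotated images.

Either way, each composition ends up with 120,000 images.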
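3. Basic PyTorch implementation of handwritten digit recognition
3.1 Complete code and results on the MNIST test set
3.1.1 Code
The complete code is as follows: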
```python
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1x28x28 -> 16x28x28 -> (pool) 16x14x14
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 16x14x14 -> 32x14x14 -> (pool) 32x7x7
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 32x7x7 -> 64x7x7 (no pooling)
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # 64*7*7 features -> 128 -> 10 raw class scores (CrossEntropyLoss applies softmax itself)
        self.fullyConnected = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=7 * 7 * 64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, img):
        output = self.conv1(img)
        output = self.conv2(output)
        output = self.conv3(output)
        output = self.fullyConnected(output)
        return output


def get_device():
    # Use the GPU when one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        train_device = torch.device('cuda')
    else:
        train_device = torch.device('cpu')
    return train_device


def get_data_loader(dat_path, bat_size, trans, to_train=False):
    # Downloads MNIST into dat_path on first use
    dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)
    if to_train is True:
        dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)
    else:
        dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)
    return dat_set, dat_loader


def show_part_of_image(dat_loader, row, col):
    # Plot the first row * col images of one batch together with their labels
    exam_img, exam_label = next(iter(dat_loader))
    plt.figure(num=1)
    for i in range(row * col):
        plt.subplot(row, col, i + 1)
        plt.tight_layout()
        plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')
        plt.title('Number: {}'.format(exam_label[i]))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def train(network, dat_loader, device, epos, loss_function, optimizer):
    for epoch in range(1, epos + 1):
        network.train(mode=True)
        for idx, (train_img, train_label) in enumerate(dat_loader):
            train_img = train_img.to(device)
            train_label = train_label.to(device)
            outputs = network(train_img)
            optimizer.zero_grad()
            loss = loss_function(outputs, train_label)
            loss.backward()
            optimizer.step()
            if idx % 100 == 0:
                # cnt: samples seen so far over all epochs, used for the overall progress percentage
                cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)
                print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(
                    epoch,
                    idx * len(train_img),
                    len(dat_loader.dataset),
                    (100 * cnt) / (len(dat_loader.dataset) * epos),
                    loss.item()))
        print('------------------------------------------------')
    print('Training ended.')
    return network


def test(network, dat_loader, device, loss_function):
    test_loss_avg, correct, total = 0, 0, 0
    test_loss = []
    network.eval()  # same as network.train(mode=False)
    with torch.no_grad():
        for idx, (test_img, test_label) in enumerate(dat_loader):
            test_img = test_img.to(device)
            test_label = test_label.to(device)
            total += test_label.size(0)
            outputs = network(test_img)
            loss = loss_function(outputs, test_label)
            test_loss.append(loss.item())
            predictions = torch.argmax(outputs, dim=1)
            correct += torch.sum(predictions == test_label).item()
    test_loss_avg = np.average(test_loss)
    print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(
        total, correct, correct / total * 100, test_loss_avg))


def show_part_of_test_result(network, dat_loader, row, col):
    exam_img, exam_label = next(iter(dat_loader))
    device = next(network.parameters()).device  # run the batch on the same device as the model
    network.eval()
    with torch.no_grad():
        predictions = torch.argmax(network(exam_img.to(device)), dim=1).cpu()
    plt.figure()
    for i in range(row * col):
        plt.subplot(row, col, i + 1)
        plt.tight_layout()
        plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')
        plt.title('Number: {}, Prediction: {}'.format(exam_label[i], predictions[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


batch_size, epochs = 64, 10
# ToTensor scales pixels to [0, 1]; 0.1307 / 0.3081 are the usual MNIST pixel mean / std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
my_device = get_device()
path = './data'
_, train_data_loader = get_data_loader(path, batch_size, transform, True)
print('Training data loaded.')
show_part_of_image(train_data_loader, 3, 3)
_, test_data_loader = get_data_loader(path, batch_size, transform)
print('Testing data loaded.')
cnn = CNN().to(my_device)  # the model must be on the same device as the input batches
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(cnn.parameters(), lr=0.01)
cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)
test(cnn, test_data_loader, my_device, loss_func)
show_part_of_test_result(cnn, test_data_loader, 5, 2)
torch.save(cnn, './cnn.pth')  # saves the whole model object (class reference + parameters)
```
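The `in_features=7 * 7 * 64` of the first linear layer follows from the layer settings:
- A 3×3 convolution with stride 1 and padding 1 keeps the side length.
- Each 2×2 max-pool with stride 2 halves it (28 -> 14 -> 7).
- conv3 outputs 64 channels.

The small pure-Python helper below (hypothetical names) applies the standard output-size formula floor((n + 2p - k) / s) + 1 to check this.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length after a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def flatten_features(side: int = 28, channels: int = 64) -> int:
    side = conv_out(side, 3, 1, 1)  # conv1: 3x3, padding 1 -> side unchanged
    side = conv_out(side, 2, 2)     # pool1: halves the side
    side = conv_out(side, 3, 1, 1)  # conv2 -> unchanged
    side = conv_out(side, 2, 2)     # pool2: halves again
    side = conv_out(side, 3, 1, 1)  # conv3 -> unchanged, no pooling afterwards
    return side * side * channels   # what nn.Flatten() hands to the first nn.Linear
```

This is also why the model only accepts 28 × 28 inputs. A different size changes the flattened length (32 × 32 would give 8 × 8 × 64), and the first linear layer would reject it.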
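3.1.2 Results on the MNIST test set
Model test results:

Some of the hyperparameters:
- batch_size: 64
- epochs: 10
- learning rate (Adam): 0.01
The loss is computed with cross-entropy (CrossEntropyLoss), and Adam is used as the optimizer.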
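`nn.CrossEntropyLoss` expects raw logits, which is why the network ends with a plain `nn.Linear` and no softmax. It applies log-softmax internally. With the default `reduction='mean'`, it then averages the negative log-probability of the true class over the batch.

A NumPy re-implementation of that computation, for illustration only (the function name is hypothetical):

```python
import numpy as np


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy over a batch, computed from raw logits like nn.CrossEntropyLoss."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # log-softmax
    # Pick the log-probability of the true class for every sample, negate, average
    return float(-log_probs[np.arange(len(targets)), targets].mean())
```

For example, all-zero logits over 10 classes give a loss of log(10) ≈ 2.30 whatever the label is. That is roughly where training starts before the network has learned anything.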

The model reaches an accuracy of 97.32% on the test set, and the sampled test predictions shown on the right are also mostly correct.
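The accuracy figure is the number of argmax predictions that match the labels, divided by the number of test images, as in `test()`. On the 10,000 test images, 97.32% corresponds to 9,732 correct predictions.

The NumPy sketch below repeats that computation. As an optional extra that is not part of the original script, it also builds a 10 × 10 confusion matrix showing which digits get mistaken for which. The function name is hypothetical.

```python
import numpy as np


def evaluate(logits: np.ndarray, labels: np.ndarray, num_classes: int = 10):
    """Return (correct, accuracy in %, confusion matrix) for a batch of class scores."""
    predictions = logits.argmax(axis=1)                  # predicted digit per image
    correct = int((predictions == labels).sum())
    accuracy = 100.0 * correct / len(labels)
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confusion, (labels, predictions), 1)       # row = true digit, column = predicted digit
    return correct, accuracy, confusion
```

The off-diagonal entries of the confusion matrix point to the specific digit pairs behind the remaining errors.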
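3.2 Testing with your own images
In addition, I drew the ten digits 0-9 in Paint and fed them to the model for recognition. Note: images drawn in Paint must be 28 × 28 pixels (they could also be resized with Python, but I took the lazy route here).
Also note that MNIST images show white digits on a black background, whereas images drawn in Paint show black digits on a white background. To get correct results, the test images must therefore be pixel-inverted as a preprocessing step.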
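The inversion and normalization steps described above can be sketched in NumPy as follows. It assumes the drawing has already been loaded as a 28 × 28 uint8 grayscale array. With Pillow this is typically `np.array(Image.open(path).convert('L'))`, and `Image.resize((28, 28))` could handle other sizes.

The constants mirror the training script's `transforms.Normalize(mean=[0.1307], std=[0.3081])`. Dividing by 255 mirrors what `transforms.ToTensor()` does for uint8 images. The function name is hypothetical, and this is not the author's (truncated) preprocessing code.

```python
import numpy as np

# Assumed to match transforms.Normalize(mean=[0.1307], std=[0.3081]) used during training
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def preprocess(gray: np.ndarray) -> np.ndarray:
    """Convert a 28x28 black-on-white uint8 drawing into a (1, 1, 28, 28) float32 batch."""
    if gray.shape != (28, 28):
        raise ValueError('expected a 28x28 grayscale image, got shape {}'.format(gray.shape))
    inverted = 255.0 - gray.astype(np.float32)      # black-on-white -> white-on-black, like MNIST
    scaled = inverted / 255.0                        # [0, 255] -> [0, 1], like ToTensor()
    normalized = (scaled - MNIST_MEAN) / MNIST_STD   # same normalization as in training
    return normalized.astype(np.float32)[np.newaxis, np.newaxis, :, :]  # add batch and channel dims
```

With PyTorch available, `torch.from_numpy(preprocess(gray))` yields a tensor that can be passed to the loaded model. `argmax(dim=1)` on its output then gives the predicted digit.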
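3.2.1 Test image preprocessing code
Note: because the model was saved to cnn.pth, the test script can load it directly with torch.load('./cnn.pth'). Alternatively, you can use the officially recommended approach of saving only the parameters. Remember to copy the network class definition into the test script as well, otherwise loading fails with an error.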
```python
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as plt


# Must be identical to the class in the training script: torch.save(cnn, ...) pickled
# only a reference to CNN, so torch.load needs this definition to rebuild the model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.fullyConnected = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=7 * 7 * 64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = self.fullyConnected(output)
        return output


# map_location='cpu' lets a model trained on the GPU load on a CPU-only machine;
# weights_only=False is needed on newer PyTorch (2.6+), which otherwise refuses to unpickle a whole model
model = torch.load('./cnn.pth', map_location='cpu', weights_only=False)
model.eval()
# Same preprocessing as during training
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
```
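The note above mentions the officially recommended alternative of saving only the parameters. A minimal sketch of that workflow uses `state_dict()`, `torch.save`, `torch.load` and `load_state_dict()`.

`TinyNet` is a hypothetical stand-in to keep the example small; for this post you would use the `CNN` class instead. The class definition is still required, because the saved file holds only tensors that get loaded into a freshly constructed model.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the CNN class; the workflow is the same for any nn.Module."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_weights(model: nn.Module, path: str) -> None:
    # Store only the parameter/buffer tensors (an ordered dict), not the pickled module object
    torch.save(model.state_dict(), path)


def load_weights(path: str) -> nn.Module:
    # Rebuild the architecture from the class definition, then fill in the saved tensors
    model = TinyNet()
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()  # switch to inference mode before testing
    return model
```

For the CNN in this post, the equivalent would be `torch.save(cnn.state_dict(), './cnn_weights.pth')` after training, then `CNN()` plus `load_state_dict` at test time. The file name is only an example.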

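This post presents a handwritten digit recognition approach based on the MNIST dataset and uses data augmentation to improve the model's generalization. Applying pixel inversion and image rotation to the original dataset yields three augmented datasets, which further improve the model's accuracy on the test set.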