pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

本文介绍了一种基于MNIST数据集的手写数字识别方法,并通过数据增强技术提高模型泛化能力。通过对原始数据集进行像素反转和图像旋转,构建了三种增强数据集,进一步提升了模型在测试集上的准确率。
该文章已生成可运行项目,

1. MNIST 手写数字识别

MNIST 数据集分为两部分,分别是训练集和测试集,其中训练集含有 60000 张图片,测试集中含有 10000 张图片。从官网下载的数据集主要包括有 4 个文件:

文件名称 文件用途
train-images-idx3-ubyte.gz 训练集图像
train-labels-idx1-ubyte.gz 训练集 label
t10k-images-idx3-ubyte.gz 测试集图像
t10k-labels-idx1-ubyte.gz 测试集 label

参考:
MNIST 数据集介绍 1
MNIST 数据集介绍 2

2. 聚焦数据集扩充后的模型训练

Internet 中有很多关于 pytorch 实现手写数字识别的博客了,所以本文不再对这一方面作过多的叙述。更多地,本文对 MNIST 数据集进行了扩充,利用 3 中不同的数据集构成对模型进行训练,每类数据集构成都包含了 12000 张图片。这 3 种不同的数据集构成如下:

  • 原始数据集(60000 张)+ 像素反转后的图片(60000 张)
  • 原始数据集(60000 张)+ 对图像进行 90°, 180°, 270° 等量均类旋转后的图片(60000 张)(注意:此处的等量均类是指对每个角度都旋转了 20000 张图片,同时,这 20000 张图片中包含了数字 0-9 这十个类别的图片各 2000 张)
  • 原始数据集(60000 张)+ 像素反转后的图片(30000 张)+ 等量均类旋转的图片(30000 张)

建议自己尝试进行数据分割,也可以利用分割好了的数据 click->已分割好了的数据

3. pytorch 手写数字识别基本实现

3.1完整代码及 MNIST 测试集测试结果

3.1.1代码

完整代码如下:

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        self.fullyConnected = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=7 * 7 * 64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, img):
        output = self.conv1(img)
        output = self.conv2(output)
        output = self.conv3(output)
        output = self.fullyConnected(output)
        return output


def get_device():
    if torch.cuda.is_available():
        train_device = torch.device('cuda')
    else:
        train_device = torch.device('cpu')

    return train_device


def get_data_loader(dat_path, bat_size, trans, to_train=False):
    dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)
    if to_train is True:
        dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)
    else:
        dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)

    return dat_set, dat_loader


def show_part_of_image(dat_loader, row, col):
    iteration = enumerate(dat_loader)
    idx, (exam_img, exam_label) = next(iteration)

    fig = plt.figure(num=1)
    for i in range(row * col):
        plt.subplot(row, col, i + 1)
        plt.tight_layout()
        plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')
        plt.title('Number: {}'.format(exam_label[i]))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def train(network, dat_loader, device, epos, loss_function, optimizer):
    for epoch in range(1, epos + 1):
        network.train(mode=True)
        for idx, (train_img, train_label) in enumerate(dat_loader):
            train_img = train_img.to(device)
            train_label = train_label.to(device)

            outputs = network(train_img)
            optimizer.zero_grad()
            loss = loss_function(outputs, train_label)
            loss.backward()
            optimizer.step()

            if idx % 100 == 0:
                cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)
                print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(epoch,
                                                                         idx * len(train_img),
                                                                         len(dat_loader.dataset),
                                                                         (100 * cnt) / (
                                                                                 len(dat_loader.dataset) * epos),
                                                                         loss.item()))
        print('------------------------------------------------')
    print('Training ended.')

    return network


def test(network, dat_loader, device, loss_function):
    test_loss_avg, correct, total = 0, 0, 0
    test_loss = []
    network.train(mode=False)

    with torch.no_grad():
        for idx, (test_img, test_label) in enumerate(dat_loader):
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            total += test_label.size(0)

            outputs = network(test_img)
            loss = loss_function(outputs, test_label)
            test_loss.append(loss.item())

            predictions = torch.argmax(outputs, dim=1)
            correct += torch.sum(predictions == test_label)
        test_loss_avg = np.average(test_loss)
        print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(total, correct,
                                                                                      correct / total * 100,
                                                                                      test_loss_avg))


def show_part_of_test_result(network, dat_loader, row, col):
    iteration = enumerate(dat_loader)
    idx, (exam_img, exam_label) = next(iteration)

    with torch.no_grad():
        outputs = network(exam_img)

        fig = plt.figure()
        for i in range(row * col):
            plt.subplot(row, col, i + 1)
            plt.tight_layout()
            plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')
            plt.title('Number: {}, Prediction: {}'.format(
                exam_label[i], outputs.data.max(1, keepdim=True)[1][i].item()
            ))
            plt.xticks([])
            plt.yticks([])
        plt.show()


batch_size, epochs = 64, 10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
my_device = get_device()

path = './data'
_, train_data_loader = get_data_loader(path, batch_size, transform, True)
print('Training data loaded.')

show_part_of_image(train_data_loader, 3, 3)

_, test_data_loader = get_data_loader(path, batch_size, transform)
print('Testing data loaded.')

cnn = CNN()
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(cnn.parameters(), lr=0.01)

cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)
test(cnn, test_data_loader, my_device, loss_func)

show_part_of_test_result(cnn, test_data_loader, 5, 2)

torch.save(cnn, './cnn.pth')

3.1.2 MNIST 测试集测试结果

模型测试结果:
在这里插入图片描述
其中一些超参数如下:

  • batch_size: 64
  • epochs: 10

同时,采用交叉熵 CrossEntropyLoss 来计算 loss,Adam 来进行优化:
在这里插入图片描述
模型在测试集上的准确率达到了 97.32%,从右侧的测试集采样结果来看,正确率也相对较高;

3.2 使用自己的图片进行测试

另外,还在画图中做了 0-9 这 10 个数字代入模型进行识别。注意:在画图中做的图片必须要是 28 * 28 的大小(当然也可以用 python 进行裁剪,这里就偷个懒~)
还需要注意的是,MNIST 数据集中的图片是黑底白字的,而通过画图做出的图片是白底黑字的,因此若想得到准确结果的话,必须要对需要测试的图片进行像素反转的预处理操作;

3.2.1 测试图片预处理代码

注意:由于将模型保存进了 cnn.pth 文件,测试时直接 torch.load('./cnn.pth') 即可(当然也可以用官方推荐的只保存参数的方法);需要注意的是:记得把网络结构的定义复制过来,否则会报错;

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as plt


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        self.fullyConnected = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=7 * 7 * 64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = self.fullyConnected(output)
        return output


model = torch.load('./cnn.pth')
model.eval()

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std
该文章已生成可运行项目
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值