CNN手写数字识别

CNN手写数字识别

使用卷积神经网络(CNN)的深度学习方法进行手写数字识别。
所有代码文件全部开源共享,可直接下载到电脑上运行。链接见文末!
先看一个gif捏
识别交互操作

0.前言

CNN是目前深度学习领域最为基础的网络,手写数字0-9的识别也是图像分类任务中最为基础的demo,因此本篇博文采用CNN的方法进行手写数字识别,通过对模型的构建、训练、测试、部署应用等多方面让读者详细了解深度学习的基本方法和原理。

1.运行环境基本介绍

1.1 实现的平台

电脑操作系统Windows11
python版本3.9.1
Python解释器Pycharm
GPURTX3050
深度学习框架Pytorch

没有GPU的也同样可以训练模型

2.构建模型并训练

2.1 LeNet分类模型

此处选用CNN开山祖师LeNet作为分类模型,模型简单易懂。
LeNet模型该模型包括3个卷积层、2个池化层、1个全连接层,因此也被称为LeNet5

①创建net.py

from torch import nn
class LeNet(nn.Module):
    
    def __init__(self):
        super(LeNet, self).__init__()
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2, stride=1)
        self.sigmoid1 = nn.Sigmoid()
        self.sigmoid2 = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)

    def forward(self, x):
        x = self.sigmoid1(self.c1(x))
        x = self.s2(x)
        x = self.sigmoid2(self.c3(x))
        x = self.s4(x)
        x = self.c5(x)
        x = self.flatten(x)
        x = self.f6(x)
        x = self.output(x)
        return x

2.2训练模型

②创建train.py

2.2.1 开导
import torch
from torch import nn
from torch.utils.data import DataLoader
from net import LeNet
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os
from tqdm import tqdm
2.2.2定义所用的相关函数

代码解释

加载数据集root=’ . /data ’data为数据集存放地址,会自动下载
加载数据集batch_size=1616为可调参数,可根据自身显存调整
损失函数nn.CrossEntropyLoss()分类常用交叉熵损失函数
优化器torch.optim.AdamAdam优化器
# 数据转换为tensor格式
data_transforms = transforms.Compose([
    transforms.ToTensor()
])
# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transforms, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transforms, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
# 如果有显卡,转到GPU上进行学习
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 调用net里面定义的模型,将模型数据转移到GPU上
model = LeNet().to(device)
# 定义一个损失函数,分类使用交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
# 定义一个优化器,此处使用随机梯度下降进行优化
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 学习率每隔10轮变为原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
2.2.3定义训练函数
# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    loss, current, n = 0.0, 0.0, 0
    loop = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch, (x, y) in loop:  # 从加载的数据集中获取数据
        # 前向传播
        x, y = x.to(device), y.to(device)  # x,y转到GPU上进行运算
        output = model(x)  # 计算结果output
        cur_loss = loss_fn(output, y)  # 计算损失
        _, pred = torch.max(output, axis=1)  # 取出结果中的最大值
        cur_acc = torch.sum(y == pred)/output.shape[0]  # 计算精度
        optimizer.zero_grad()  # 优化器清0
        cur_loss.backward()  # 损失反向传播
        optimizer.step()  # 优化器更新
        loss += cur_loss.item()  # 累积损失值
        current += cur_acc.item()  # 累积准确个数
        n += 1  # 统计样本数
    print('tran_loss'+str(loss/n))
    print('tran_acc'+str(current/n))
2.2.4定义验证函数
# 验证函数
def val(dataloader, model, loss_fn):
    model.eval()  # 开启验证模式
    loss, current, n = 0.0, 0.0, 0
    with torch.no_grad():  # 关闭梯度
        loop = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch, (x, y) in loop:  # 从加载的数据集中获取数据
            # 前向传播
            x, y = x.to(device), y.to(device)
            output = model(x)
            cur_loss = loss_fn(output, y)
            _, pred = torch.max(output, axis=1)
            cur_acc = torch.sum(y == pred) / output.shape[0]
            loss += cur_loss.item()
            current += cur_acc.item()
            n += 1
        print('val_loss'+str(loss/n))
        print('val_acc'+str(current/n))
        return current/n  # 返回验证精度
2.2.5 迭代30次

设置epoch=30

epoch = 30
min_acc = 0.0
for i in range(epoch):
    print(f'epoch:{i+1}\n-------')
    train(train_dataloader, model, loss_fn, optimizer)
    a = val(test_dataloader, model, loss_fn)
    # 保存最好的模型权重
    if a > min_acc:
        folder = 'save_model_test'
        if not os.path.exists(folder):  # 检测是否存在文件
            os.mkdir('save_model_test')  # 不存在文件则创建对应文件
        min_acc = a
        print('save best model')
        torch.save(model.state_dict(), 'save_model_test/best_model.pth')
print('END)  # 运行结束

执行该train.py开始训练

执行结束,得到如图文件:
其中data文件夹下为自动下载的数据集
save_model_test文件夹下为最好的一代权重
执行train.py

3.测试模型

随机从数据集选取5张图片,再通过模型测试。

③创建test.py

直接执行该文件,得到测试结果

import torch
from torch.utils.data import DataLoader
from net import LeNet
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage

data_transforms = transforms.Compose([
    transforms.ToTensor()
])
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transforms, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
model = LeNet()
# 权重地址
model.load_state_dict(torch.load('./save_model_test/best_model.pth'))
# 获取结果
classes = [
    '0',
    '1',
    '2',
    '3',
    '4',
    '5',
    '6',
    '7',
    '8',
    '9',
]
# 把tensor转化为图片方便可视化
show = ToPILImage()
i = 0
for i in range(5):
    x, y = test_dataset[i][0], test_dataset[i][1]
    show(x).show()
    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
    with torch.no_grad():
        pred = model(x)
        predicted, actual = classes[torch.argmax(pred[0])], classes[y]
        if predicted == actual:
            i += 1
        print(f'预测值:{predicted},真实值:{actual}')

测试结果
测试了5张,全部正确,运行中途会弹出原图片,手动关闭就行。

4.交互程序

使用Pyqt5编写交互程序,使用鼠标书写数字再传入模型进行识别。

④创建qt_test.py

import sys
import torch
from torch import nn
from PIL import ImageQt
from PyQt5.QtCore import Qt
from torchvision import transforms
from torch.autograd import Variable
from torchvision.transforms import ToPILImage
from PyQt5.Qt import QPainter, QPoint, QPen
from PyQt5.Qt import QWidget, QColor, QPixmap, QIcon, QSize
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton, QSplitter, \
    QComboBox, QLabel, QFileDialog, QApplication


# 画板界面设计
class PaintBoard(QWidget):

    def __init__(self, Parent=None):

        super().__init__(Parent)
        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):

        self.__size = QSize(280, 280)   # 新建QPixmap作为画板,尺寸为__size
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.black)     # 用白色填充画板

        self.__IsEmpty = True           # 默认为空画板
        self.EraserMode = False         # 默认为禁用橡皮擦模式

        self.__lastPos = QPoint(0, 0)     # 上一次鼠标位置
        self.__currentPos = QPoint(0, 0)  # 当前的鼠标位置

        self.__painter = QPainter()  # 新建绘图工具

        self.__thickness = 25  # 默认画笔粗细为25px
        self.__penColor = QColor("white")  # 设置默认画笔颜色为白色

    def __InitView(self):
        # 设置界面的尺寸为__size
        self.setFixedSize(self.__size)

    def Clear(self):
        # 清空画板
        self.__board.fill(Qt.black)
        self.update()
        self.__IsEmpty = True

    def IsEmpty(self):
        # 返回画板是否为空
        return self.__IsEmpty

    def GetContentAsQImage(self):
        # 获取画板内容(返回QImage)
        image = self.__board.toImage()
        return image

    def paintEvent(self, paintEvent):
        # 绘图事件
        # 绘图时必须使用QPainter的实例,此处为__painter
        # 绘图在begin()函数与end()函数间进行
        # begin(param)的参数要指定绘图设备,即把图画在哪里
        # drawPixmap用于绘制QPixmap类型的对象
        self.__painter.begin(self)
        # 0,0为绘图的左上角起点的坐标,__board即要绘制的图
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        # 鼠标按下时,获取鼠标的当前位置保存为上一次位置
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        # 鼠标移动时,更新当前位置,并在上一个位置和当前位置间画线
        self.__currentPos = mouseEvent.pos()
        self.__painter.begin(self.__board)

        if self.EraserMode == False:
            # 非橡皮擦模式
            self.__painter.setPen(QPen(self.__penColor, self.__thickness))  # 设置画笔颜色,粗细
        else:
            # 橡皮擦模式下画笔为纯白色,粗细为10
            self.__painter.setPen(QPen(Qt.white, 10))

        # 画线
        self.__painter.drawLine(self.__lastPos, self.__currentPos)
        self.__painter.end()
        self.__lastPos = self.__currentPos

        self.update()  # 更新显示

    def mouseReleaseEvent(self, mouseEvent):
        self.__IsEmpty = False  # 画板不再为空


# 网络
class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2, stride=1)
        self.sigmoid1 = nn.Sigmoid()
        self.sigmoid2 = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)

    def forward(self, x):
        x = self.sigmoid1(self.c1(x))
        x = self.s2(x)
        x = self.sigmoid2(self.c3(x))
        x = self.s4(x)
        x = self.c5(x)
        x = self.flatten(x)
        x = self.f6(x)
        x = self.output(x)
        return x


data_transforms = transforms.Compose([
    transforms.ToTensor()
])
model = LeNet()
model.load_state_dict(torch.load('./save_model_test/best_model.pth'))

classes = [
    '0',
    '1',
    '2',
    '3',
    '4',
    '5',
    '6',
    '7',
    '8',
    '9',
]


show = ToPILImage()

def pic_net(img):
    img = img.resize((28, 28))
    y = data_transforms(img)
    x = y
    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
    # 输入到网络中
    with torch.no_grad():
        pred = model(x)
        predicted = classes[torch.argmax(pred[0])]
        return predicted


class MainWidget(QWidget):

    def __init__(self, Parent=None):
        '''
        Constructor
        '''

        super().__init__(Parent)

        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):

        '''
                  初始化成员变量
        '''

        self.__paintBoard = PaintBoard(self)
        # 获取颜色列表(字符串类型)
        self.__colorList = QColor.colorNames()

    def __InitView(self):

        self.setFixedSize(500, 450)
        self.setWindowTitle("手写数字体验")
        self.i = 0
        self.setWindowIcon(QIcon('logo.png'))
        # 新建一个水平布局作为本窗体的主布局
        main_layout = QHBoxLayout(self)
        # 设置主布局内边距以及控件间距为10px
        main_layout.setSpacing(10)

        # 在主界面左侧放置画板
        main_layout.addWidget(self.__paintBoard)

        # 新建垂直子布局用于放置按键
        sub_layout = QVBoxLayout()

        # 设置此子布局和内部控件的间距为10px
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__btn_Clear = QPushButton("清空手写区")
        self.__btn_Clear.setParent(self)  # 设置父对象为本界面

        # 将按键按下信号与画板清空函数相关联
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Quit = QPushButton("退出窗口")
        self.__btn_Quit.setParent(self)  # 设置父对象为本界面
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        self.__btn_Save = QPushButton("保存图片")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__btn_test = QPushButton("进行识别测试")
        self.__btn_test.setParent(self)
        self.__btn_test.clicked.connect(self.test)
        sub_layout.addWidget(self.__btn_test)

        self.__label_show = QLabel()
        self.__label_show.setText('暂未进行数字识别!')
        self.__label_show.setParent(self)
        sub_layout.addWidget(self.__label_show)

        splitter = QSplitter(self)  # 占位符
        sub_layout.addWidget(splitter)

        main_layout.addLayout(sub_layout)  # 将子布局加入主布局

    def __fillColorList(self, comboBox):

        index_black = 0
        index = 0
        for color in self.__colorList:
            if color == "black":
                index_black = index
            index += 1
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
            comboBox.setIconSize(QSize(70, 20))
            comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self):
        savePath = QFileDialog.getSaveFileName(self, 'Save Your Paint', '.\\', '*.png')
        print(savePath)
        if savePath[0] == "":
            print("Save cancel")
            return
        image = self.__paintBoard.GetContentAsQImage()
        image.save(savePath[0])

    def on_cbtn_Eraser_clicked(self):
        if self.__cbtn_Eraser.isChecked():
            self.__paintBoard.EraserMode = True  # 进入橡皮擦模式
        else:
            self.__paintBoard.EraserMode = False  # 退出橡皮擦模式

    def Quit(self):
        self.close()


    def test(self):

        image = self.__paintBoard.GetContentAsQImage()
        # 获取QIamage图片并转换为PIL格式图片
        image = ImageQt.fromqimage(image).convert('L')
        self.i = pic_net(image)
        self.__label_show.setText(f'<font size=4>识别结果为:{self.i}</font>')


if __name__ == '__main__':

    app = QApplication(sys.argv)
    mainWidget = MainWidget()  # 新建一个主界面
    mainWidget.show()  # 显示主界面
    exit(app.exec_())  # 进入消息循环

执行结果图片
交互

5.完整代码分享

5.1 Gitee代码

链接: https://gitee.com/ming-ming-0201/handwritten-number-recognition.git

5.2 百度网盘

链接: https://pan.baidu.com/s/1CYK9sePbxNpxnVIEeLpXUw?pwd=kzdq
提取码:kzdq

  • 5
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
卷积神经网络(Convolutional Neural Networks,简称CNN)是一种在图像处理中非常重要的神经网络结构。它可以用于手写体数字识别等任务。下面是一个使用CNN进行手写体数字识别的示例: ```python import tensorflow as tf from tensorflow.keras.datasets import mnist # 加载MNIST数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据预处理 x_train = x_train.reshape(-1, 28, 28, 1) / 255.0 x_test = x_test.reshape(-1, 28, 28, 1) / 255.0 y_train = tf.keras.utils.to_categorical(y_train, num_classes=10) y_test = tf.keras.utils.to_categorical(y_test, num_classes=10) # 构建卷积神经网络模型 model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), tf.keras.layers.MaxPooling2D((2, 2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 编译模型 model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) # 训练模型 model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test)) # 评估模型 test_loss, test_acc = model.evaluate(x_test, y_test) print('Test accuracy:', test_acc) ``` 这个示例使用了TensorFlow和Keras库来构建和训练一个卷积神经网络模型,用于手写体数字识别任务。模型首先加载了MNIST数据集,并进行了数据预处理。然后,通过添加卷积层、池化层、全连接层和输出层来构建了一个简单的卷积神经网络模型。最后,使用训练集对模型进行训练,并使用测试集评估模型的准确性。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值