LeNet手写数字识别(数据集下载+模型创建+模型训练+PyQt界面)

1.目录结构

data.py文件:下载数据集文件

model.py文件:LeNet网络模型定义文件

train.py文件:训练下载的数据集

predict.py文件:利用训练好的模型去预测

predict_ui.py文件:用于ui界面进行预测

ui.py文件:PyQt界面程序,识别自己在画板上写的数字

2.手写数字数据集下载

创建一个data.py文件把下边代码复制进去,运行后就会开始下载数据集。这是使用的pytorch库下载的。

from torchvision import transforms
import torchvision

mnist_train = torchvision.datasets.MNIST(root='./datasets/',
                                         train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./datasets/',
                                        train=False, download=True, transform=transforms.ToTensor())

print(len(mnist_train), len(mnist_test))  # 打印训练/测试集大小
feature, label = mnist_train[0]
print(feature.shape, label)  # 打印图像大小和标签

 下载完成后目录下会出现datasets文件夹,里边存放的就是我们下载的数据集。

运行后也会显示我们下载的数据集大小。60000个训练数据集,10000个测试数据集。

3.LeNet网络模型

LeNet网络模型定义在model.py文件里边,下边是model.py文件里边定义的LeNet网络模型。

import torch
from torch import nn


class LeNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(  # (-1,1,28,28)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),  # (-1,6,28,28)
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # (-1,6,14,14)
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),  # (-1,16,10,10)
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # (-1,16,5,5)
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),  # (-1,120)
            nn.Sigmoid(),
            nn.Linear(120, 84),  # (-1,84)
            nn.Sigmoid(),
            nn.Linear(in_features=84, out_features=10)  # (-1,10)
        )

    def forward(self, x):
        return self.model(x)


leNet = LeNet()
print(leNet)

运行后可以查看网络的结构。

4.训练模型

训练模型的由train.py文件来完成,新建一个train.py文件把下边代码放进去。

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn



class LeNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(  # (-1,1,28,28)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),  # (-1,6,28,28)
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # (-1,6,14,14)
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),  # (-1,16,10,10)
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # (-1,16,5,5)
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),  # (-1,120)
            nn.Sigmoid(),
            nn.Linear(120, 84),  # (-1,84)
            nn.Sigmoid(),
            nn.Linear(in_features=84, out_features=10)  # (-1,10)
        )

    def forward(self, x):
        return self.model(x)


# 创建模型
leNet = LeNet()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
leNet = leNet.to(device)  # 若支持GPU加速
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.Adam(leNet.parameters(), lr=learning_rate)
total_train_step = 0  # 总训练次数
epoch = 20 # 训练轮数

# 数据
mnist_train = torchvision.datasets.MNIST(root='./datasets/',
                                         train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./datasets/',
                                        train=False, download=True, transform=transforms.ToTensor())
dataloader_train = DataLoader(mnist_train, batch_size=64, num_workers=0)  # 每次批量加载64张
dataloader_test = DataLoader(mnist_test, batch_size=64, num_workers=0)  # 每次批量加载64张

for i in range(epoch):
    print("-----第{}轮训练开始-----".format(i + 1))
    leNet.train()  # 训练模式
    train_loss = 0
    for data in dataloader_train:
        imgs, labels = data
        imgs = imgs.to(device)  # 数据放到device里,要和model一样
        labels = labels.to(device)
        outputs = leNet(imgs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()  # 清空之前梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        total_train_step += 1  # 步数
        train_loss += loss.item()


torch.save(leNet.state_dict(),'LeNet.pth') # 保存模型

训练模型,要注意这个模型保存,用的是保存模型的参数。训练结束后(我训练了20轮)目录下会生成一个LeNet.pth文件,这就是我们训练的权重。

 

5.测试训练的模型

测试训练的模型由predict.py完成。创建一个predict.py文件把下边代码放进去。可以来测试自己的手写数字。注意改成自己的图片路径。

import torch
from PIL import Image
import torchvision.transforms as transforms
import torch
from model import leNet
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
import torchvision
from torch.utils.data import DataLoader

# 定义图像预处理流水线
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 将图片缩放到固定的尺寸,如224x224
    transforms.ToTensor(),  # 转换为张量
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 归一化
])

# 假设图片文件名为"my_image.jpg"
img_path = "4.png"

# 读取图片
with Image.open(img_path) as img:
    # 应用预处理
    tensor_img = transform(img)

#print(tensor_img.shape)

def rgb_to_grayscale(image):
    # 假设image是一个Tensor,形状可能是(C, H, W)(通道数、高度、宽度)
    grayscale = image.mean(dim=0, keepdim=True)
    return grayscale
tensor_img = rgb_to_grayscale(tensor_img)

#print(tensor_img)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = leNet.to(device)
model.load_state_dict(torch.load("LeNet.pth"))
classes = [
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
]

x = Variable(torch.unsqueeze(tensor_img, dim=0).float(), requires_grad=False).to(device)
with torch.no_grad():   #关闭反向传播
    pred = model(x)
    # argmax(input):返回指定维度最大值的序号
    # 得到验证类别中数值最高的那一类,再对应classes中的那一类
    #print(pred)
    predicted, actual = classes[torch.argmax(pred[0])], classes[0]
    # 输出预测值与真实值
    print(f'predicted: "{predicted}"')

这里我在文件夹下放入了,我自己写的4。

预测结果正确!

6.PyQt界面程序

创建一个predict_ui.py文件把下边代码复制进去。

import torch
from PIL import Image
import torchvision.transforms as transforms
import torch
from model import leNet
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
import torchvision
from torch.utils.data import DataLoader



def get_predict():
    # 定义图像预处理流水线
    transform = transforms.Compose([
        transforms.Resize((28, 28)),  # 将图片缩放到固定的尺寸,如224x224
        transforms.ToTensor(),  # 转换为张量
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 归一化
    ])

    # 假设图片文件名为"my_image.jpg"
    #img_path = "images/test.png"
    img = Image.open("test.png")  # 加载图片,自定义的图片名称

    # 读取图片
    #with Image.open(img_path) as img:
        # 应用预处理
    tensor_img = transform(img)

    # print(tensor_img.shape)

    def rgb_to_grayscale(image):
        # 假设image是一个Tensor,形状可能是(C, H, W)(通道数、高度、宽度)
        grayscale = image.mean(dim=0, keepdim=True)
        return grayscale

    tensor_img = rgb_to_grayscale(tensor_img)

    # print(tensor_img)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = leNet.to(device)
    model.load_state_dict(torch.load("LeNet.pth"))
    classes = [
        "0",
        "1",
        "2",
        "3",
        "4",
        "5",
        "6",
        "7",
        "8",
        "9",
    ]

    x = Variable(torch.unsqueeze(tensor_img, dim=0).float(), requires_grad=False).to(device)
    with torch.no_grad():
        pred = model(x)
        # argmax(input):返回指定维度最大值的序号
        # 得到验证类别中数值最高的那一类,再对应classes中的那一类
        # print(pred)
        predicted, actual = classes[torch.argmax(pred[0])], classes[0]
        # 输出预测值与真实值
        print(f'predicted: "{predicted}"')
        return predicted

#get_predict()

创建一个ui.py文件把下边代码复制进去。

from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtWidgets import QApplication
from PyQt5.Qt import QPainter, QPoint, QPen
from PyQt5.QtCore import Qt
from PyQt5.Qt import QWidget, QColor, QPixmap, QIcon, QSize, QCheckBox
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton,QComboBox, QLabel, QSpinBox
from predict_ui import get_predict
import sys


def main():
    app = QApplication(sys.argv)
    mainWidget = MainWidget()
    mainWidget.show()
    exit(app.exec_())


class PaintBoard(QWidget):
    def __init__(self, Parent=None):
        super().__init__(Parent)
        self.__InitData()
        self.__InitView()
        self.setWindowTitle("画笔")

    def __InitData(self):
        self.__size = QSize(480, 460)

        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.black)  # 用白色填充画板

        self.__IsEmpty = True  # 默认为空画板
        self.EraserMode = False  # 默认为禁用橡皮擦模式

        self.__lastPos = QPoint(0, 0)  # 上一次鼠标位置
        self.__currentPos = QPoint(0, 0)  # 当前的鼠标位置

        self.__painter = QPainter()  # 新建绘图工具

        self.__thickness = 20
        self.__penColor = QColor("white")
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        # 设置界面的尺寸为__size
        self.setFixedSize(self.__size)

    def Clear(self):
        # 清空画板
        self.__board.fill(Qt.black)
        self.update()
        self.__IsEmpty = True


    def ChangePenThickness(self, thickness=10):
        # 改变画笔粗细
        self.__thickness = thickness

    def IsEmpty(self):
        # 返回画板是否为空
        return self.__IsEmpty

    def GetContentAsQImage(self):
        # 获取画板内容(返回QImage)
        image = self.__board.toImage()
        return image

    def paintEvent(self, paintEvent):
        self.__painter.begin(self)
        # 0,0为绘图的左上角起点的坐标,__board即要绘制的图
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        # 鼠标按下时,获取鼠标的当前位置保存为上一次位置
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        # 鼠标移动时,更新当前位置,并在上一个位置和当前位置间画线
        self.__currentPos = mouseEvent.pos()
        self.__painter.begin(self.__board)

        if self.EraserMode == False:
            # 非橡皮擦模式
            self.__painter.setPen(QPen(self.__penColor, self.__thickness))  # 设置画笔颜色,粗细

        # 画线
        self.__painter.drawLine(self.__lastPos, self.__currentPos)
        self.__painter.end()
        self.__lastPos = self.__currentPos

        self.update()  # 更新显示

    def mouseReleaseEvent(self, mouseEvent):
        self.__IsEmpty = False  # 画板不再为空


class MainWidget(QWidget):
    def __init__(self, Parent=None):
        super().__init__(Parent)
        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):
        """
        初始化成员变量
        """
        self.__paintBoard = PaintBoard(self)
        # 获取颜色列表(字符串类型)
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        """
        初始化界面
        """
        self.setFixedSize(700, 480)
        self.setWindowTitle("LeNet手写数字识别器")

        # 新建一个水平布局作为本窗体的主布局
        main_layout = QHBoxLayout(self)
        # 设置主布局内边距以及控件间距为10px
        main_layout.setSpacing(10)

        # 在主界面左侧放置画板
        main_layout.addWidget(self.__paintBoard)

        # 新建垂直子布局用于放置按键
        sub_layout = QVBoxLayout()

        # 设置此子布局和内部控件的间距为10px
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)  # 设置父对象为本界面

        # 将按键按下信号与画板清空函数相关联
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)  # 设置父对象为本界面
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)



        self.__label_penThickness = QLabel(self)
        self.__label_penThickness.setText("画笔粗细")
        self.__label_penThickness.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penThickness)

        self.__spinBox_penThickness = QSpinBox(self)
        self.__spinBox_penThickness.setMaximum(24)
        self.__spinBox_penThickness.setMinimum(20)
        self.__spinBox_penThickness.setValue(20)  # 默认粗细为10
        self.__spinBox_penThickness.setSingleStep(2)  # 最小变化值为2
        self.__spinBox_penThickness.valueChanged.connect(
            self.on_PenThicknessChange)  # 关联spinBox值变化信号和函数on_PenThicknessChange
        sub_layout.addWidget(self.__spinBox_penThickness)



        self.__btn_Save = QPushButton("LeNet预测")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__textbox = QLineEdit(self)
        self.__textbox.setReadOnly(True)
        sub_layout.addWidget(self.__textbox)
        main_layout.addLayout(sub_layout)  # 将子布局加入主布局

    def __fillColorList(self, comboBox):
        index_black = 0
        index = 0
        for color in self.__colorList:
            if color == "black":
                index_black = index
            index += 1
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
            comboBox.setIconSize(QSize(70, 20))
            comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self): # 按钮点击事件触发
        image = self.__paintBoard.GetContentAsQImage()
        image.save('test.png')  # 默认保存为test.png文件
        ans = get_predict() # 调用函数进行预测
        self.__textbox.setText("预测结果为:" + ans)

    def on_cbtn_Eraser_clicked(self):
        if self.__cbtn_Eraser.isChecked():
            self.__paintBoard.EraserMode = True  # 进入橡皮擦模式
        else:
            self.__paintBoard.EraserMode = False  # 退出橡皮擦模式

    def Quit(self):
        self.close()


if __name__ == '__main__':
    main()

 运行ui.py文件,演示效果。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值