MNIST手写体识别

今天实现了 MNIST 手写数字数据集的训练;代码中只使用了训练集,并没有使用测试集进行评估。

"""
@Title: 训练手写体0-9数字
@Time: 2023/11/26 14:01
@Author: Michael
"""

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from PIL import Image


# 搭建神经网络
class Model(nn.Module):
    """Small CNN: two conv/pool stages followed by three fully connected layers.

    Shape trace (N = batch size):
        (N, 1, 28, 28) -> Conv 5x5, pad 2 -> (N, 6, 28, 28) -> MaxPool 2 -> (N, 6, 14, 14)
        -> Conv 5x5 -> (N, 16, 10, 10) -> MaxPool 2 -> (N, 16, 5, 5)
        -> Flatten -> (N, 400) -> Linear 120 -> Linear 84 -> Linear 10.
    The output is raw logits; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(1, 6, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        ]
        # One flat Sequential stored as `model1` keeps parameter names
        # (model1.0.weight, model1.3.weight, ...) identical, so saved
        # checkpoints remain loadable.
        self.model1 = nn.Sequential(*conv_stages, *head)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) logits."""
        return self.model1(x)


# 训练
def train(lr=0.01, batch_size=100, epoch=50, data_root="mnist_dataset",
          save_path="model.pkl", log_dir="logs", shuffle=True):
    """Train ``Model`` on the MNIST training split and save the whole model.

    Only the training set is used (no validation/test evaluation). The summed
    per-batch training loss of each epoch is printed and logged to TensorBoard
    under the tag ``train_loss``.

    Args:
        lr: SGD learning rate.
        batch_size: Mini-batch size.
        epoch: Number of passes over the training set.
        data_root: Directory holding the MNIST files; downloaded there if missing.
        save_path: File the whole model is written to with ``torch.save``.
        log_dir: TensorBoard log directory.
        shuffle: Reshuffle the training set every epoch (standard for SGD).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model().to(device)
    loss_fun = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # download=True: MNIST defaults to download=False and raises if the files
    # are not already under data_root.
    train_set = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           download=True,
                                           transform=torchvision.transforms.ToTensor())
    # Without shuffling, every epoch sees the batches in the same fixed order.
    train_data_loader = DataLoader(dataset=train_set, batch_size=batch_size,
                                   shuffle=shuffle)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        model.train()
        for i in range(epoch):
            loss_sum = 0.0
            for img, target in train_data_loader:
                img = img.to(device)
                target = target.to(device)

                optimizer.zero_grad()  # clear gradients from the previous step
                outputs = model(img)  # forward pass
                loss = loss_fun(outputs, target)
                loss.backward()  # backpropagation
                optimizer.step()  # parameter update
                loss_sum += loss.item()

            # Log the summed (not averaged) batch loss for this epoch.
            print("训练{}次,train_loss:{}".format(i + 1, loss_sum))
            writer.add_scalar("train_loss", loss_sum, i + 1)

        # Saves the whole pickled module (not just a state_dict); practice()
        # loads it back with torch.load.
        torch.save(model, save_path)
    finally:
        # Flush/close the log even if training is interrupted.
        writer.close()


# 应用
def practice(img_path, model_path):
    """Classify one digit image with a whole model saved by ``torch.save``.

    Preprocessing mirrors training: grayscale ('L'), resize to 28x28, then
    ``ToTensor`` (scales to [0, 1]; no normalization).
    NOTE(review): MNIST digits are typically light strokes on a dark
    background; a dark-on-light input image may need inverting — confirm.

    Args:
        img_path: Path of the image to classify.
        model_path: Path of a whole pickled model (e.g. ``model.pkl`` from train()).

    Returns:
        The predicted digit as an ``int`` (it is also printed).
    """
    import inspect

    # Load the image; the context manager closes the underlying file.
    with Image.open(img_path) as raw:
        img = raw.convert('L')

    # Preprocess to a single-sample batch of shape (1, 1, 28, 28).
    pro_img = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)),
                                              torchvision.transforms.ToTensor()])
    img = torch.reshape(pro_img(img), (1, 1, 28, 28))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine, and
    # .to(device) guarantees model and input end up on the same device.
    # torch >= 2.6 defaults to weights_only=True, which rejects a whole pickled
    # module, so request full unpickling when the argument exists.
    # SECURITY: full unpickling can execute arbitrary code — only load
    # checkpoint files you trust.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs).to(device)
    model.eval()

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(img.to(device))
    pred = int(output.argmax(1).item())

    dict_target = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    print('识别类型为:{}'.format(dict_target[pred]))
    return dict_target[pred]


if __name__ == '__main__':
    # Script entry: classify the sample image "1.png" with "model.pkl",
    # the file name train() writes its model to.
    practice(img_path="1.png", model_path="model.pkl")
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值