【Pytorch】基于CNN识别手写汉字的识别

继上次写的MNIST数字识别代码,我又开始尝试手写汉字的识别啦
于是上网查资料,本篇博客主要参考了此博客,并在其基础上增加了一些给像我一样小白同学的内容,若此博客内容侵犯了您的权益,请与我联系及时删除 😃
重点:原作者博客:「Pytorch」CNN实现手写汉字识别(数据集制作,网络搭建,训练验证测试全部代码)
在这先感谢此博客大佬的无私分享,给我们这些小白提供了巨大的帮助!

首先放上一些对下文理解有帮助的函数介绍,有像我一样什么也不懂的同学可以参考一下 😃
可能有点多,都是容易理解的内容,需要耐心哦
如果有大佬知道好的网站,欢迎推荐啊!! 廖雪峰老师的教程很不错!

① 导入模块

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

② 超参数定义

EPOCH = 10 # 训练次数
BATCH_SIZE = 50 # 数据集划分
LR = 0.001 # 学习率

③ 提取数据集的路径

将每个汉字的图片集都标上标签,这里有100个数字的图片集
在这里插入图片描述
全部数据资料在此文章作者的网盘中,数据集超级全面但是也超级大,下载需要一定时间,大家可以到这里下载上图中的数据运行操作一下

def classes_txt(root, out_path, num_class=None):

    dirs = os.listdir(root) # 列出根目录下所有类别所在文件夹名
    if not num_class:		# 不指定类别数量就读取所有
        num_class = len(dirs)

    if not os.path.exists(out_path): # 输出文件路径不存在就新建
        f = open(out_path, 'w')
        f.close()
	# 如果文件中本来就有一部分内容,只需要补充剩余部分
	# 如果文件中数据的类别数比需要的多就跳过
    with open(out_path, 'r+') as f:
        try:
            end = int(f.readlines()[-1].split('/')[-2]) + 1
        except:
            end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for dir in dirs:
                files = os.listdir(os.path.join(root, dir))
                for file in files:
                    f.write(os.path.join(root, dir, file) + '\n')

④ 数据集的设置

这里是程序中较难理解的一处(对于我来说…),一些函数的使用方法我都放在了开头

class MyDataset(Dataset):
    def __init__(self, txt_path, num_class, transforms=None):
        super(MyDataset, self).__init__()
        images = [] # 存储图片路径
        labels = [] # 存储类别名,在本例中是数字
        # 打开上一步生成的txt文件
        with open(txt_path, 'r') as f:
            for line in f:
                if int(line.split('\\')[-2]) >= num_class:  # 只读取前 num_class 个类
                    break
                line = line.strip('\n')
                images.append(line)
                labels.append(int(line.split('\\')[-2]))
        self.images = images
        self.labels = labels
        self.transforms = transforms # 图片需要进行的变换,ToTensor()等等

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB') # 用PIL.Image读取图像
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image) # 进行变换
        return image, label

    def __len__(self):
        return len(self.labels)

⑤ 神经网络的搭建

顺序为:卷积→池化→卷积→全连接→全连接→输出100个汉字的概率

class NetSmall(nn.Module):
    def __init__(self):
        super(NetSmall, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3) # 3个参数分别是in_channels,out_channels,kernel_size,还可以加padding
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(2704, 512)
        self.fc2 = nn.Linear(512, 84)
        self.fc3 = nn.Linear(84, args.num_class) # 命令行参数,后面解释

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 2704)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

⑥ 函数主程序

# 首先将训练集和测试集文件途径和文件名以txt保存在一个文件夹中,路径自行定义
root = 'E:\pytorch_test/HWDB1_data' # 这是我文件的储存位置
classes_txt(root + '/train', root+'/train.txt')
classes_txt(root + '/test', root+'/test.txt')

# 由于数据集图片尺寸不一,因此要进行resize(重设大小)
transform = transforms.Compose([transforms.Resize((64,64)), # 将图片大小重设为 64 * 64
                                transforms.Grayscale(),
                                transforms.ToTensor()])

# 提取训练集和测试集图片的路径
train_set = MyDataset(root + '/train.txt', num_class=100, transforms=transform) # num_class 选取100种汉字  提出图片和标签
test_set = MyDataset(root + '/test.txt', num_class =100, transforms = transform)
# 放入迭代器中
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True) 
test_loader = DataLoader(test_set, batch_size=5473, shuffle=True) 
# 这里的5473是因为测试集为5973张图片,当进行迭代时取第二批500个图片进行测试
for step, (x,y) in enumerate(test_loader):
    test_x, labels_test = x.to(device), y.to(device)

⑦ 参数优化

model = NetSmall()
optimizer = torch.optim.Adam(model.parameters(), lr=LR) # 参数优化
loss_func = nn.CrossEntropyLoss() #分类误差计算函数
device = torch.device('cpu')
model.to(device)

⑧ 模型训练

最后一步了,每50步便输出模型的准确率

for epoch in range(EPOCH):
    for step, (x,y) in enumerate(train_loader):
        picture, labels = x.to(device), y.to(device)
        
        output = model(picture)
        loss = loss_func(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 50 == 0:
            test_output = model(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels_test).sum().item() / labels_test.size(0)
            print('Epoch:', epoch, '| train loss:%.4f' % loss.data, '| test accuracy:', accuracy)
            # 输出训练次数、误差、准确率
print('Finish training')

运行结果

准确挺高的,可能还是数据集太少了(笑哭)

Epoch: 8 | train loss:0.1372 | test accuracy: 0.834
Epoch: 8 | train loss:0.2623 | test accuracy: 0.798
Epoch: 8 | train loss:0.2576 | test accuracy: 0.816
Epoch: 8 | train loss:0.1368 | test accuracy: 0.816
Epoch: 8 | train loss:0.1696 | test accuracy: 0.812
Epoch: 8 | train loss:0.2071 | test accuracy: 0.82
Epoch: 8 | train loss:0.4139 | test accuracy: 0.832
Epoch: 8 | train loss:0.2181 | test accuracy: 0.84
Epoch: 8 | train loss:0.1976 | test accuracy: 0.804
Epoch: 9 | train loss:0.0905 | test accuracy: 0.814
Epoch: 9 | train loss:0.0659 | test accuracy: 0.806
Epoch: 9 | train loss:0.1603 | test accuracy: 0.81
Epoch: 9 | train loss:0.0418 | test accuracy: 0.81
Epoch: 9 | train loss:0.1181 | test accuracy: 0.816
Epoch: 9 | train loss:0.1963 | test accuracy: 0.808
Epoch: 9 | train loss:0.1848 | test accuracy: 0.814
Epoch: 9 | train loss:0.1844 | test accuracy: 0.812
Epoch: 9 | train loss:0.1395 | test accuracy: 0.82
Epoch: 9 | train loss:0.1666 | test accuracy: 0.816
Finish training

当准确率出来的那一瞬间,心中的石头终于放下了,又解锁一个成就,还是要感谢这位博主,为我提供了很大的帮助,大家可以去看原来的博客的代码哦,我在原博主的代码上,根据我学的知识和上一篇博客衔接起来,改写了一小部分地方,可以看一下我上一篇Pytorch实现CNN识别手写数字MNIST的博客哦~

寒假的最后一天终于完成了任务,明天就要开始上网课啦,上面的代码和数据集我都放在下面,有需要的同学可自取哦,如果有哪里不明白,随时私信我,我及时为你们补充上!

100个汉字图片集+标签
Pytorch实现CNN汉字的识别

©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页