「Pytorch」CNN实现手写汉字识别(数据集制作,网络搭建,训练验证测试全部代码)

之前毕业设计用TensorFlow做了手写汉字识别,使用的中科院的数据集。
参考了一篇博客: TensorFlow与中文手写汉字识别
现在用Pytorch复现一下。
Github下载链接在文末

如果有问题可以在评论区评论或者私信我,提问之前还请点个赞支持一下头秃博主哦

环境:
Pytorch:1.0.1 GPU版
Ubuntu:16.04
Python:3.5.2

1 数据集整理:

分为 train 和 test 文件夹,每个文件夹下每一类都分一个子文件夹并编号。
在这里插入图片描述
这是为了方便用 Python 做一个 txt 文件,指明所有图片数据的路径。在自定义数据集类的时候会用到。

如果你没有数据集可以参考 TensorFlow与中文手写汉字识别 前面的部分下载及处理数据集。

2 import

import os
import torch
import torch.nn as nn
import torch.nn.functional as  F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import argparse # 提取命令行参数

3 提取图片路径

一个函数就可以实现了:

def classes_txt(root, out_path, num_class=None):
    '''
    write image paths (containing class name) into a txt file.
    :param root: data set path
    :param out_path: txt file path
    :param num_class: how many classes needed
    :return: None
    '''
    dirs = os.listdir(root) # 列出根目录下所有类别所在文件夹名
    if not num_class:		# 不指定类别数量就读取所有
        num_class = len(dirs)

    if not os.path.exists(out_path): # 输出文件路径不存在就新建
        f = open(out_path, 'w')
        f.close()
	# 如果文件中本来就有一部分内容,只需要补充剩余部分
	# 如果文件中数据的类别数比需要的多就跳过
    with open(out_path, 'r+') as f:
        try:
            end = int(f.readlines()[-1].split('/')[-2]) + 1
        except:
            end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for dir in dirs:
                files = os.listdir(os.path.join(root, dir))
                for file in files:
                    f.write(os.path.join(root, dir, file) + '\n')

4 自定义数据集

只需要重写 Dataset 里的 __init__, __getitem__, __len__ 就可以了。
__getitem__在训练的时候返回输入网络的数据,图片和标签等等,需要和训练测试的程序配合。
__len__ 返回数据集长度。

class MyDataset(Dataset):
    def __init__(self, txt_path, num_class, transforms=None):
        super(MyDataset, self).__init__()
        images = [] # 存储图片路径
        labels = [] # 存储类别名,在本例中是数字
        # 打开上一步生成的txt文件
        with open(txt_path, 'r') as f:
            for line in f:
                if int(line.split('/')[-2]) >= num_class:  # 只读取前 num_class 个类
                    break
                line = line.strip('\n')
                images.append(line)
                labels.append(int(line.split('/')[-2]))
        self.images = images
        self.labels = labels
        self.transforms = transforms # 图片需要进行的变换,ToTensor()等等

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB') # 用PIL.Image读取图像
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image) # 进行变换
        return image, label

    def __len__(self):
        return len(self.labels)

5 搭建神经网络

这里用一个简单的网络进行示例。
两层卷积,三层全连接,20个类别的情况下可以训练至95%以上的准确率。

class NetSmall(nn.Module):
    def __init__(self):
        super(NetSmall, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3) # 3个参数分别是in_channels,out_channels,kernel_size,还可以加padding
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(2704, 512)
        self.fc2 = nn.Linear(512, 84)
        self.fc3 = nn.Linear(84, args.num_class) # 命令行参数,后面解释

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 2704)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

补充:关于torch.nn 和 torch.nn.functional 的区别
实现的功能比较近似,functional 是以函数的方式实现的,nn 是以类的方式实现的。所以 nn 封装的更好,可以在反向传播时实现自动保存导数等功能,但是在具体实现时还是调用了 functional 的函数。
如果使用了dropout的话需要更改模型的state, model.train(), model.eval()
参考:https://www.zhihu.com/question/66782101

6 train, validation and inference

这三个函数实现比较相似,train()会增加一层遍历数据集的循环,以及计算loss和反向传播。

def train():
	# 由于我的数据集图片尺寸不一,因此要进行resize,这里还可以加入数据增强,灰度变换,随机剪切等等
    transform = transforms.Compose([transforms.Resize((args.image_size, args.image_size)),
                                    transforms.Grayscale(),
                                    transforms.ToTensor()])

    train_set = MyDataset(args.root + '/train.txt', num_class=args.num_class, transforms=transform)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
	# 选择使用的设备
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    model = NetSmall()
    model.to(device)
	# 训练模式
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
	# 由命令行参数决定是否从之前的checkpoint开始训练
    if args.restore:
        checkpoint = torch.load(args.log_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['loss']
        epoch = checkpoint['epoch']
    else:
        loss = 0.0
        epoch = 0

    while epoch < args.epoch:
        running_loss = 0.0

        for i, data in enumerate(train_loader):
        # 这里取出的数据就是 __getitem__() 返回的数据
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outs = model(inputs)
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % 200 == 199:  # every 200 steps
                print('epoch %5d: batch: %5d, loss: %f' % (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
		# 保存 checkpoint
        if epoch % 10 == 9:
            print('Save checkpoint...')
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss},
                       args.log_path)
        epoch += 1

    print('Finish training')


def validation():
    transform = transforms.Compose([transforms.Resize((args.image_size, args.image_size)),
                                    transforms.Grayscale(),
                                    transforms.ToTensor()])

    test_set = MyDataset(args.root + '/test.txt', num_class=args.num_class, transforms=transform)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = NetSmall()
    model.to(device)

    checkpoint = torch.load(args.log_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    total = 0.0
    correct = 0.0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            inputs, labels = data[0].cuda(), data[1].cuda()
            outputs = model(inputs)
            _, predict = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += sum(int(predict == labels)).item()
            # 根据评论区反馈,如果上面这句报错,可以换成下面这句试试:
            # correct += (predict == labels).sum().item()

            if i % 100 == 99:
                print('batch: %5d,\t acc: %f' % (i + 1, correct / total))
    print('Accuracy: %.2f%%' % (correct / total * 100))


def inference():
    print('Start inference...')
    transform = transforms.Compose([transforms.Resize((args.image_size, args.image_size)),
                                    transforms.Grayscale(),
                                    transforms.ToTensor()])

    f = open(args.root + '/test.txt')
    num_line = sum(line.count('\n') for line in f)
    f.seek(0, 0)
    # 在文件中随机取一个路径
    line = int(torch.rand(1).data * num_line - 10) # -10 for '\n's are more than lines
    while line > 0:
        f.readline()
        line -= 1
    img_path = f.readline().rstrip('\n')
    f.close()
    label = int(img_path.split('/')[-2])
    print('label:\t%4d' % label)
    input = Image.open(img_path).convert('RGB')
    input = transform(input)
    # 网络默认接受4维数据,即[Batch, Channel, Heigth, Width],所以要加1个维度
    input = input.unsqueeze(0)
    model = NetSmall()
    model.eval()
    checkpoint = torch.load(args.log_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    output = model(input)
    _, pred = torch.max(output.data, 1)
    
    print('predict:\t%4d' % pred)

7 命令行参数设置

设定命令行参数可以在不修改程序的情况下更改一些需要调整的参数,比如batch_size, resize 之后的image_size, epoch的值, 模式等等。写在程序的 import 部分后面即可。

parse = argparse.ArgumentParser(description='Params for training. ')
# 数据集根目录
parse.add_argument('--root', type=str, default='/home/chenyiran/character_rec/data', help='path to data set')
# 模式,3选1
parse.add_argument('--mode', type=str, default='train', choices=['train', 'validation', 'inference'])
# checkpoint 路径
parse.add_argument('--log_path', type=str, default=os.path.abspath('.') + '/log.pth', help='dir of checkpoints')

parse.add_argument('--restore', type=bool, default=True, help='whether to restore checkpoints')

parse.add_argument('--batch_size', type=int, default=16, help='size of mini-batch')
parse.add_argument('--image_size', type=int, default=64, help='resize image')
parse.add_argument('--epoch', type=int, default=100)
# 我的数据集类别数是3755,所以给定了一个选择范围
parse.add_argument('--num_class', type=int, default=100, choices=range(10, 3755))
args = parse.parse_args()

关于argpase的使用可以参考官方文档:https://docs.python.org/3.5/library/argparse.html

8 最后,主程序

很简单

if __name__ == '__main__':

    classes_txt(args.root + '/train', args.root + '/train.txt', num_class=args.num_class)
    classes_txt(args.root + '/test', args.root + '/test.txt', num_class=args.num_class)

    if args.mode == 'train':
        train()
    elif args.mode == 'validation':
        validation()
    elif args.mode == 'inference':
        inference()

以上就是所有的程序了。

直接下载请移步:

https://github.com/chenyr0021/Chinese_character_recognition/tree/master

知乎:@陈小白233
公众号:一本正经的搬砖日常

不点个赞再走嘛

  • 186
    点赞
  • 437
    收藏
    觉得还不错? 一键收藏
  • 122
    评论
评论 122
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值