基于Pytorch的手写汉字识别

1说明

这篇文章的代码是参照了另外一个大神的博客「Pytorch」CNN实现手写汉字识别(数据集制作,网络搭建,训练验证测试全部代码),非常感谢大神奉献的知识!
由于我是小白,其实大神的一些代码还是有些不太明白的地方,所以学习的过程中,在原博客的基础上加入我的一些理解与知识点的补充,有不对的地方希望可以指出
如果代码部分侵犯了原博客的相关权益,请联系我,我会删除的。

2数据集

数据是有一个train文件夹和一个test文件夹,每个文件夹中包含3755个小的文件夹,每个小的文件夹中包含若干个图片,这里的图片就是我们的数据啦!
train文件夹和test文件夹
在这里插入图片描述
在这里插入图片描述

需要说明的是,小的文件夹的文件名即为该文件夹内所有图片(数据)的类别标签。
如某个图片的路径是B:/pytorch/train/00000/12407.png,则00000即为该图片的类别标签。

了解数据后,接下来进入pytorch的世界!😀

3代码解析

再次申明,代码部分来自「Pytorch」CNN实现手写汉字识别(数据集制作,网络搭建,训练验证测试全部代码)

3.1导入库

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import argparse  # 提取命令行参数

这里每个库是干什么的,后面再详细介绍

3.2argparse模块——对参数进行设置

parse = argparse.ArgumentParser(description='Params for training. ')
# 数据集根目录,根据自己数据的位置进行修改,应该是root的下级文件夹就有train和test文件夹
parse.add_argument('--root', type=str, default='B:/pytorch', help='path to data set')
# 模式,3选1,当default为train时,就会运行下面的trian函数,就只进行模型训练
parse.add_argument('--mode', type=str, default='inference', choices=['train', 'validation', 'inference', 'inference_all'])
# checkpoint 路径,用来保存模型训练的状态
parse.add_argument('--log_path', type=str, default=os.path.abspath('.') + '/log.pth', help='dir of checkpoints')
# 第一次训练时,需要将restore设为False,等到代码运行一段时间生成log.pth文件后,停止训练,将其改为True,再继续运行即可
parse.add_argument('--restore', type=bool, default=True, help='whether to restore checkpoints')

parse.add_argument('--batch_size', type=int, default=16, help='size of mini-batch')
parse.add_argument('--image_size', type=int, default=64, help='resize image')
# 迭代次数
parse.add_argument('--epoch', type=int, default=100)
# 数据集类别数是3755,所以给定了一个选择范围,只对100个类别的数据进行训练和预测
parse.add_argument('--num_class', type=int, default=100, choices=range(10, 3755))
args = parse.parse_args()  # 解析参数

argparse模块是命令行选项,参数和子命令解析器,
可以轻松编写用户友好的命令行接口,
程序定义它需要的参数,argparse将弄清楚如何从sys.argv解析出那些参数。
会自动生成帮助和使用手册,并在用户给程序传入无效参数时报错。

上面是一段云里雾里的解释,我的理解是这个模块就是用来管理参数的,可以直接在这里设置下面代码块需要的参数。
具体使用步骤:

  1. 使用argparse.ArgumentParser(description='Params for training. ')创建一个解析器。这里的代码将解析器命名为parser。其中的description参数就是对这个解析器的一个描述。
  2. 使用parse.add_argument(’–num_class’, type=int, default=100, choices=range(10, 3755))往解析器里添加后面代码需要设置的参数。其中**’–num_class’**是参数名设置,实际上num_class是参数名,在这里需要在前面加上–符号,type为类型,default为默认值设置,choices为默认值可以选择的范围。
  3. 使用args = parse.parse_args()解析参数
  4. 在后面代码中需要引用这里的参数时,需要在参数前面加上’args.’,如后面会出现args.num_class等等。

对于代码中的参数设置,代码中也给出了注释,这里不再赘述。
不过还需要强调以下:第一次训练时,需要将restore设为False,等到代码运行一段时间生成log.pth文件后,停止训练,将其改为True,再继续运行即可!!!!

3.3提取图片路径

# 提取图片路径
def classes_txt(root, out_path, num_class=None):
    """

    write image paths (containing class name) into a txt file.
    :param root: data set path
    :param out_path: txt file path
    :param num_class: how many classes needed
    :return: None
    """
    dirs = os.listdir(root)  # 列出根目录下所有类别所在文件夹名
    if not num_class:		# 不指定类别数量就读取所有
        num_class = len(dirs)

    if not os.path.exists(out_path):  # 输出文件路径不存在就新建
        f = open(out_path, 'w')
        f.close()
    # 如果文件中本来就有一部分内容,只需要补充剩余部分
    # 如果文件中数据的类别数比需要的多就跳过
    with open(out_path, 'r+') as f:
        try:
            end = int(f.readlines()[-1].split('/')[-2]) + 1
        except:
            end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for dir in dirs:
                files = os.listdir(os.path.join(root, dir))
                for file in files:
                    f.write(os.path.join("%s/%s/%s" % (root, dir, file)) + '\n')

这一部分的代码实际上就是将root目录下的文件夹名称写入一个txt文件中,在后面的主函数中会调用这个代码:

classes_txt(args.root + '/train', args.root + '/train.txt', num_class=args.num
  • 9
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值