CRNN项目实战

CRNN项目实战

之前写过一篇文章利用CRNN进行文字识别,当时重点讲的CRNN网络结构和CNN部分的代码实现,因为缺少文字数据集没有进行真正的训练,这次正好有一批不定长的字符验证码,正好CRNN主要就是用于端到端地对不定长的文本序列进行识别,当然是字符和文字都是可以用的,所以这里进行了一次实战。

主要是参考github项目:https://github.com/meijieru/crnn.pytorch

关于lmdb

lmdb安装

首先关于lmdb这个数据库,python有两个包,一个是lmdb,另一个是python-lmdb。

使用pycharm的包安装功能可以看到关于lmdb的描述

Universal Python binding for the LMDB 'Lightning' Database Version 1.3.0

关于python-lmdb的描述

simple lmdb bindings written using ctypes Version 1.0.0

所以理论上我们安装前者肯定是可以用的,但是经过亲身实践,

在pip环境中使用pip install lmdb确实可以正常使用;

但是在conda环境中,使用conda install lmdb安装完成之后却无法导入包。

所以又使用:conda install python-lmdb安装,安装完之后却可以使用,非常奇怪。

后发现原因大概率是版本问题,使用pip可以安装lmdb=1.3.0的最新版本,而conda只能安装lmdb=0.9.x的版本,所以目前在conda中只能使用python-lmdb暂替使用。

制作适用CRNNlmdb数据集

github项目中关于如何训练自己的数据集写的不是很清楚,如果我们直接运行train.py会遇到各种问题,首先第一个问题就是数据集的问题,lmdbDataset中的初始化

self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False
        )

这里会报错,因为这里读取的路径下需要有lmdb格式的数据,所以在这之前我们需要生成lmdb格式的数据集。

相关代码如下:

# -*- coding: utf-8 -*-
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import glob
import numpy as np


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):   # 在python3环境下运行
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if type(v) is str:
                txn.put(k.encode(), v.encode())
                continue
            txn.put(k.encode(), v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList)), 'len(x) != len(y)'
    nSamples = len(imagePathList)
    print('...................')
    # env = lmdb.open(outputPath, map_size=104857600)  # 最大100MB
    env = lmdb.open(outputPath, map_size=10485760)

    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

        # .mdb数据库文件保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key
        imageKey = f'image-{cnt}'
        labelKey = f'label-{cnt}'
        cache[imageKey] = imageBin
        cache[labelKey] = label

        if lexiconList:
            lexiconKey = f'lexicon-{cnt}'
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    with open(path) as f:
        text = f.read()
    text = text.strip()

    return text


if __name__ == '__main__':

    # lmdb 输出目录
    # outputPath = 'data/train/lmdb_data'  # 训练数据
    outputPath = 'data/val/lmdb_data'    # 验证数据

    # 训练图片路径,标签是txt格式,名字跟图片名字要一致,如123.jpg对应标签需要是123.txt
    # input_path = 'data/train/origin_data/*.jpg'
    input_path = 'data/val/origin_data/*.jpg'

    imagePathList = glob.glob(input_path)
    print('------------', len(imagePathList), '------------')
    imgLabelLists = []
    for p in imagePathList:
        try:
            # imgLabelLists.append((p, read_text(p.replace('.jpg', '.txt'))))
            imgLabelLists.append((p, p.split('_')[2].replace('.jpg', '')))
        except Exception as _e:
            print(_e)
            continue

    # sort by labelList
    imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
    imgPaths = [p[0] for p in imgLabelList]
    txtLists = [p[1] for p in imgLabelList]

    createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

代码执行完成会在相应目录下生成data.mdb、lock.mdb两个文件。代码很简单,一看就懂,因为我的原始数据集,标签包含在图片名称中,不是另外存储在txt文件中,所以对相应代码进行了改动,另外把python2相关的东西改成了可以用python3运行。

另外有一点:

env = lmdb.open(outputPath, map_size=104857600) # 最大100MB

map_size需要根据自己的数据集设置大小(单位是B),运行完生成的data.mdb的大小就是设置的大小,如果设置的比较大造成空间的浪费,设置的比较小可能会不够用(默认应该是10MB)。

很多资料都写这里设置1T,如果你的电脑硬盘空间不够,就会报错(报的错误是乱码)。

另外,如果你还遇到了其他乱码报错,大概率是路径错误。

参考文章:https://www.cnblogs.com/yanghailin/p/14519525.html

CTCLoss

在train.py中有这么一行代码,

from warpctc_pytorch import CTCLoss

初次使用的话一般是显示没有这个包的,而pytorch(version>=1.1)其实是有CTCLoss模块的

from torch.nn import CTCLoss

所以如果你的pytorch版本满足,就无需额外安装warp_ctc_pytorch了,替换一下导入代码即可。如果你的版本比较低,还是需要手动安装这个包的,如果是Windows环境下,比较麻烦的就是需要安装cmake来编译文件。不再赘述。

需要用到的warp_ctc_pytorch: https://github.com/SeanNaren/warp-ctc

参考文章:https://blog.csdn.net/weixin_40437821/article/details/105473032

然后简单介绍下,pytorchCTCLoss的用法。

初始化

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')

类初始化参数说明:

**blank:**空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定

**reduction:**处理output losses的方式,string类型,可选’none’ 、 ‘mean’ 及 ‘sum’,'none’表示对output losses不做任何处理,‘mean’ 则对output losses取平均值处理,‘sum’则是对output losses求和处理,默认为’mean’ 。

计算损失

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

CTCLoss()对象调用形参说明:

**log_probs:**shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;

**targets:**shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;

**input_lengths:**shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;

**target_lengths:**shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;

这里最重要的就是初始化blank参数的设置和计算损失时,log_probs参数需要先进行log_softmax,这也是我们在这个项目中需要调整的点,如果我们直接从warp_ctc_pytorch更换为pytorch内置的CTCLoss,然后其他的不改动的话,是训练不出来结果的。

改动点:

criterion = CTCLoss(blank=0, reduction='mean')  # 初始化

cost = criterion(preds.log_softmax(2), text, preds_size, length) / batch_size  # 损失计算,共有两处,训练和验证

比较巧合的是,在这个项目中,0的位置就是为空白字符预留的,而且blank的默认值也为0,所以不改动也是可以的。

训练

配置训练数据路径(trainroot)、验证数据路径(valroot)、预训练权重路径(pretrained),将lr(学习率)设置为0.001,nepoch=200。

训练的过程中会报很多错误,因为这个GitHub项目可能部分代码写的比较粗糙,另一方面也是因为python2python3,Linux和windows的环境问题。

我遇到了以下错误:

1、trainRoot,valRoot需要改下大小写

2、TypeError: Won't implicitly convert Unicode to bytes; use .encode()

按照错误提示加上encode
txn.get(‘num-samples’.encode())
label_byte = txn.get(label_key.encode())
imgbuf = txn.get(img_key.encode())

3、ValueError: sampler option is mutually exclusive with shuffle

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batchSize,
    shuffle=False,  # sampler不为None,shuffle就需要为False
    sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)
)

因为本来shuffle参数为True,需要更改为False。

参考文章:https://blog.csdn.net/hjxu2016/article/details/111300972

4、TypeError: cannot pickle 'Environment' object

这个是因为在windows环境下,多进程训练有问题,将workers参数设置为0即可。

参考文章:https://blog.csdn.net/weixin_43272781/article/details/112757371

5、AttributeError: module 'torch' has no attribute 'longTensor'

在dataset.py脚本中,应该是torch.LongTensor。

6、TypeError: randint() takes 3 positional arguments but 4 were given

random.randint(0, len(self), self.batch_size)

我猜这里是笔误了,应该是len(self)-self.batch_size。

7、train.py中

image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)

推测最后一个参数应该是opt.imgW。

8、RuntimeError: The expanded size of the tensor (64) must match the existing size (63) at non-singleton dimension 0. Target sizes: [64]. Tensor sizes: [63]

嘿嘿,这个问题是因为我改了torch.range代码,因为pycharm提示说这个方法被废弃了,要求使用torch.arange

batch_index = random_start + torch.arange(0, self.batch_size-1)
index[i * self.batch_size:(i+1)*self.batch_size] = batch_index

torch.range(start=1, end=6) 的结果是会包含end的,而torch.arange(start=1, end=6)的结果并不包含end。所以这里就不需要减1了。

batch_index = random_start + torch.arange(0, self.batch_size)

同样的,后面取末尾元素的时候也要去掉减1

tail_index = random_start + torch.arange(0, tail)

参考文章:https://blog.csdn.net/lunhuicnm/article/details/106712026

9、RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().

v.data.resize_(data.size()).copy_(data)
v.resize_(data.size()).copy_(data)

将上面的替换为下面的

参考文章:https://blog.csdn.net/weixin_45292103/article/details/102736742

10、还有一个关于utils.py中编码器(encode()方法的问题),

本来有一行代码是:_str = unicode(_str, 'utf-8')

这是python2中的语法,python3并不需要,但是因为前面制作lmdb数据集的时候,我们的标签进行了encode()处理,也就是这一步:txn.put(k.encode(), v.encode())

导致后面训练的时候,在对标签进行编码时,标签传过来是这样一种形式:“b’jdvfl0k’”

所以可以加上这行代码:_str = _str.replace("b'", "").replace("'", '')

11、在train.py的验证方法里,有这么一段代码:

_, preds = preds.max(2)
preds = preds.squeeze(2) 
preds = preds.transpose(1, 0).contiguous().view(-1)

我在实际运行时发现preds.squeeze(2)会报错,然后调试发现preds此时的shape为(26, 64),所以应该无需squeeze,可以将这行代码注释。

另外,还要根据情况配置displayInterval、valInterval、saveInterval参数

displayInterval默认值是500,我这里训练不定长字符验证码,训练集总共4500张图片,batch_size=64,总共71个批次,所以设置displayInterval=10,每10个批次打印一次损失情况。其他两个参数同理,按照自己的情况调整。

终于把所有问题都解决了,可以正常训练了,但是训练的过程中打印的测试数据让我感觉不太对劲,我发现在对比预测标签和真实标签时,真实标签的形式为:“b’jdvfl0k’”

和上面同样的问题,需要处理下:

target = target.replace("b'", "").replace("'", "")
gt = gt.replace("b'", "").replace("'", "")

因为使用的预训练权重,训练的比较快,训练过程示例:

[0/200][10/71] Loss: 0.07960973680019379
[0/200][20/71] Loss: 0.014807680621743202
[0/200][30/71] Loss: 0.010894499719142914
Start val
66----y----f----g----p---- => 6yfgp, gt:6yfgp
-f-t--r---z---x---d---8--- => ftrzxd8, gt:ftpzxd8
-k--a----y---m-----w------ => kaymw, gt:kaymw
-n-----3-----7-----z------ => n37z, gt:n37z
--u-------y----l----9----- => uyl9, gt:uyl9
ss---y---8---e---l----p--- => sy8elp, gt:sy8elp
--y--d--h--t--f--m----y--- => ydhtfmy, gt:ydtfmy
--zz----zz----l-----1----- => zzl1, gt:zzl1
Test loss: 0.016945159062743187, accuracy: 0.465
[0/200][40/71] Loss: 0.007085741963237524
[0/200][50/71] Loss: 0.007241038139909506
[0/200][60/71] Loss: 0.004275097511708736
[0/200][70/71] Loss: 0.0034677726216614246
Start val
-n----n----3----k----l---- => nn3kl, gt:nn3kl
-c---c----b---s---a---k--- => ccbsak, gt:ccbsak
-u---v---7--k---z--s--4--- => uv7kzs4, gt:uv7kzs4
-5----e------u------------ => 5eu, gt:5eu
-k-----e-----v-------v---- => kevv, gt:kelbv
--yy-d-----t--f--m----y--- => ydtfmy, gt:ydtfmy
-5--q---h---t--m---8--q--- => 5qhtm8q, gt:5qhtm8q
--y-----z----g-------v---- => yzgv, gt:yzgv
Test loss: 0.014375979080796242, accuracy: 0.66
[1/200][10/71] Loss: 0.004606468137353659
[1/200][20/71] Loss: 0.0033515722025185823
[1/200][30/71] Loss: 0.002877553692087531
Start val
-7--d----o---h----c---y--- => 7dohcy, gt:7d0hcy
-6----y----f----g----p---- => 6yfgp, gt:6yfgp
-1----f---j----b-----f---- => 1fjbf, gt:1fjbf
-2---o----u---h--4----z--- => 2ouh4z, gt:2ouh4z
--x----1---b---ww-----n--- => x1bwn, gt:x1bwn
-j---k--------------55---- => jk5, gt:jk15
-1---t--6---x----n----o--- => 1t6xno, gt:1t6xno
--y---x--j---k----a----t-- => yxjkat, gt:yxjkat
Test loss: 0.012275747954845428, accuracy: 0.745
[1/200][40/71] Loss: 0.0018801590194925666
[1/200][50/71] Loss: 0.002281028078868985
[1/200][60/71] Loss: 0.001854069298133254
[1/200][70/71] Loss: 0.0012131230905652046
Start val
dd----y--j-c---v--y---m--- => dyjcvym, gt:dyjcvym
-s----z---9----qq------t-- => sz9qt, gt:sz9qt
--y--qq----qq-----t---h--- => yqqth, gt:yqqth
-h---a----5--s----xx--o--- => ha5sxo, gt:ha5sxo
tt----o----7----a---0----- => to7a0, gt:to4ao
-n-----xx---f---e-----s--- => nxfes, gt:nxfes
-1---t--6---x----n----o--- => 1t6xno, gt:1t6xno
-c---e---d---e---c---m---- => cedecm, gt:cedecm
Test loss: 0.004259577952325344, accuracy: 0.815

两轮训练结束之后准确率就达到了81.5%(200个验证图片)

经过35轮的训练,准确率可以稳定在97.5%左右。

我的gtihub博客地址:https://forchenxi.github.io/

另外,如果对投资理财感兴趣的同学,可以关注我的微信公众号:运气与实力。

评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值