0. 前言
至于CRNN网络的细节这里就不再多言了,网上有很多关于crnn的介绍,这里直接讲一下代码的实现流程
1. 数据集准备
CRNN是识别文本的网络,所以我们首先需要构建数据集,使用26个小写字母以及0到9十个数字,一共有36个字符,从这36个字符中随机选择4到9个字符(这里要说明一下,网上很多关于crnn的训练集中每张图片中的字符个数是一样的,这就具有很大的局限性。所以本文使用4到9随机选择字符个数构建图片。)
生成数据集代码如下:
import cv2
import numpy as np
import random
import imgaug.augmenters as iaa
def get_img():
    """Generate one synthetic text image plus its label sequence.

    Draws 4-9 random characters (a-z, 0-9) onto a white canvas, applies a
    random augmentation pipeline, and resizes to the fixed CRNN input size.

    Returns:
        src: augmented image as a numpy uint8 array of shape (32, 200, C),
             NOT normalized.
        lab: list of int class indices (position of each drawn character in
             the 36-char alphabet), variable length 4-9.
    """
    zfu = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
           'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    # Variable-length labels (4-9 chars) avoid the fixed-length limitation
    # of many public CRNN training sets.
    k = random.randint(4, 9)
    select = random.choices(zfu, k=k)
    lab = [zfu.index(i) for i in select]
    select = "".join(select)
    font = cv2.FONT_HERSHEY_COMPLEX
    # White 50x250 BGR canvas; np.full is clearer than ones*255.
    src = np.full(shape=(50, 250, 3), fill_value=255, dtype='uint8')
    src = cv2.putText(src, select, (20, 27), font, 1, (0, 0, 0), 2)
    seq = iaa.Sequential([
        iaa.Multiply((0.5, 1.5)),          # random brightness change
        iaa.GaussianBlur(sigma=(0, 1.0)),  # blur with std-dev in [0, 1.0]
        iaa.Crop(percent=(0, 0.06)),
        iaa.Grayscale(alpha=(0, 1)),
        iaa.Affine(
            scale=(0.95, 1.05),            # mild random scaling
            mode=iaa.ia.ALL,               # random border fill mode
            cval=(100, 255)
        ),
        # Final fixed size expected by the CRNN input layer.
        iaa.Resize({"height": 32, "width": 200})
    ])
    # imgaug expects a batch axis; take the single image back out afterwards.
    src = np.expand_dims(src, axis=0)
    src = seq(images=src)[0]
    return src, lab
def _write_split(txt_path, img_dir, count):
    """Generate `count` images, save them under img_dir, and record
    'path label0 label1 ...' lines in txt_path.

    Uses a context manager so the annotation file is always flushed and
    closed (the original left both file handles open for the whole run).
    """
    with open(txt_path, 'w') as f:
        for i in range(count):
            img, lab = get_img()
            # lab is a list of int class indices; store them space-separated.
            lab = " ".join(str(c) for c in lab)
            path = img_dir + '/' + str(i) + '.jpg'
            cv2.imwrite(path, img)
            f.write(path + ' ' + lab + '\n')
            print(i)


# train_data/ and val_data/ must already exist before running.
_write_split('train.txt', 'train_data', 10000)
_write_split('val.txt', 'val_data', 1000)
运行上述代码之前首先需要手动新建两个空文件夹用于存放训练图像和验证图像,文件夹名字分别是:train_data和val_data。运行完上述代码以后会在train_data文件夹中保存10000张训练图像,在val_data文件夹中保存1000张验证图像。此外还会生成两个txt文件,分别为train.txt和val.txt。
txt文本中存放的是图片的路径及包含字符的类别,如下所示:
部分训练图像如下所示:
2. 构建网络
构建crnn网络的代码如下所示:
# crnn.py
import argparse, os
import torch
import torch.nn as nn
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Takes a sequence-first tensor of shape [seq, batch, nInput_size] and
    returns per-timestep class scores of shape [seq, batch, nOut].
    """

    def __init__(self, nInput_size, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)
        # Bidirectional output concatenates both directions: 2 * nHidden.
        self.linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        seq_features, _ = self.lstm(input)
        seq_len, batch_size, feat_dim = seq_features.size()
        # Flatten time and batch so one Linear call covers every timestep.
        flat = seq_features.view(seq_len * batch_size, feat_dim)
        scores = self.linear(flat)                     # [seq*batch, nOut]
        return scores.view(seq_len, batch_size, -1)    # [seq, batch, nOut]
class CNN(nn.Module):
def __init__(self, imageHeight, nChannel):
super(CNN, self).__init__()
assert imageHeight %