基于CRNN实现验证码识别

最近在做的一个小项目,进行简单的分享。现在不论是手机APP、网页、电脑客户端等,只要涉及用户登录的界面基本都需要输入验证码核实身份真伪。因此有时,我们也会有自动识别验证码的需求,例如:希望实现业务流程自动化时,用户登录作为流程第一步可能就会用到验证码的识别,将该模型部署为接口进行调用就可以完成该功能。

本项目基于非常经典的文本识别算法CRNN来进行验证码识别模型的训练,整个流程是基于PaddlePaddle来训练的,后边也会放上Pytorch版本的模型代码。项目中代码直接使用时,需要自己加上 import python包 的部分。

1. 数据准备

1.1 数据集准备

这次分享的项目其实是一个简单的示例,使用一个比较简单的大写字母加数字的数据集进行训练,因此字符的丰富度可能比较有限,在真实场景下可以基于自己的需要找更加合适的数据集进行训练。这里把我找数据集的链接发给大家,就不放自己的数据集了,把大概准备数据的过程分享给大家。

1.2 准备标签文件

准备标签文件,这里因为我的数据集图片命名就是label的值,所以准备标签的方式比较简单。这里只是一种比较简单的数据准备方式,大家可以根据自己数据的情况进行标注和标签文件的准备。

#生成总的标签文件
train_path = "pic"
SUM = []
for root,dirs, files in os.walk(train_path): # 分别代表根目录、文件夹、文件
    for file in files:              
        imgpath = os.path.join(root, file)
        SUM.append(imgpath+"\t"+file.split(".")[0]+"\n")
    # 生成总标签文件
    allstr = ''.join(SUM)
    f = open('total_list.txt','w',encoding='utf-8')
    f.write(allstr)
    f.close
print("数据集数量:{}".format(len(SUM)))

生成总的标签文件后就可以划分训练集和验证集,训练集和验证集的比例也可以自己去定。

random.shuffle(SUM)
train_len = int(len(SUM) * 0.8)
test_list = SUM[:train_len]
train_list = SUM[:train_len]
print('训练集数量: {}, 验证集数量: {}'.format(len(train_list),len(test_list)))
#生成训练集的标签文件
train_txt = ''.join(train_list)
f_train = open('train_list.txt','w',encoding='utf-8')
f_train.write(train_txt)
f_train.close()
#生成测试集的标签文件
test_txt = ''.join(test_list)
f_test = open('test_list.txt','w',encoding='utf-8')
f_test.write(test_txt)
f_test.close()

1.3 准备数据字典

在OCR-文本识别任务中,有一个特别需要准备的文件就是字典。文本识别的结果最终包含于字典文件中的字符集,也就是字典文件中有的字符才有可能作为最终识别的结果,没有的字符也就不会作为结果进行输出。在这个项目里,字典中的字符集也就应该是所有大写字母加上数字的集合。

#准备字典
class_set = set()
lines = []
file = open("total_list.txt","r",encoding="utf-8")#待转换文档,这里我们使用的是数据集的标签文件
for i in file:
    a=i.strip('\n').split('\t')[-1]
    lines.append(a)
file.close
for line in lines:
    for e in line:
        class_set.add(e)
class_list = list(class_set)
class_list.sort()
print("class num: {0}".format(len(class_list)))
with codecs.open("new_dict.txt", "w", encoding='utf-8') as label_list:
    for id, c in enumerate(class_list):
        label_list.write("{0}\n".format(c))

1.4 可视化观察一张样本

img = Image.open('9APK.png')
img = np.array(img)

# 画出读取的图片
plt.figure(figsize=(10, 10))
plt.imshow(img)

在这里插入图片描述

2. 数据预处理

在数据灌入模型前,需要对数据进行预处理操作,使得图片和标签满足网络训练和预测的需要。这里简单实现了如下方法:

  • 图像解码:将图像转为Numpy格式;
  • 编码标签:将标签按照CTC(Connectionist temporal classification)算法要求进行编码。其中,字符串中每个字符替换为其在字符字典中的索引值,规定标签的最大长度max_text_len,如果标签中字符个数小于max_text_len,则剩余位置补0,例如规定max_text_len=10,标签为[2322],字符字典为[0,1,2,3,4,5,6,7,8,9],则编码后的标签为[2,3,2,2,0,0,0,0,0,0];
  • 缩放图像并归一化:将原图片的高度统一缩放到32,归一化后贴在尺寸为[3,32,100]的空白画布上;
  • 返回图像、标签、长度:将保存在字典中的数据取出,以列表的形式返回,列表中元素顺序分别为 image, label, length。

图像解码

class DecodeImage(object):
    # 图像解码
    def __init__(self, img_mode='BGR', channel_first=False):
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        # 解码图像并返回结果
        img = data['image']
        img = np.frombuffer(img, dtype='uint8')
        img = cv2.imdecode(img, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data

编码标签

将字符格式的标签转换为索引格式,如果不足 max_text_len 个,则在最后进行补零。

def encode(text, max_text_len, dict_index):
    # 将字符标签转换为对应的索引值
    # 如果没有字符或字符个数超过上限,返回None
    if len(text) == 0 or len(text) > max_text_len:
        return None
    # 将字符的索引值依次保存到text_list
    text_list = []
    for char in text:
        # 如果字符在字符字典没有出现,不进行保存
        if char not in dict_index:
            continue
        text_list.append(dict_index[char])
    if len(text_list) == 0:
        return None
    return text_list
class CTCLabelEncode(object):
    # 编码标签
    def __init__(self, max_text_length=25, character_dict_path='new_dict.txt'):
        self.max_text_length = max_text_length
        # 将标签编码为CTC格式
        character_str = ""
        # 读取字符字典
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip("\n").strip("\r\n")
                character_str += line
        dict_character = list(character_str)
        # 添加类别:分隔符
        dict_character = ['blank'] + dict_character
        # 将每个类别对应的索引保存到字典中
        self.dict_index = {
   }
        for i, char in enumerate(dict_character):
            self.dict_index[char] = i

    def __call__(self, data):
        # 获取数据的标签
        text = data['label']
        # 将标签转换为索引
        text = encode(text, self.max_text_length, self.dict_index)
        if text is None:
            return None
        data['length'] = np.array(len(text))
        text = text + [0] * (self.max_text_length - len(text))
        data['label'] = np.array(text)
        return data

缩放图像并标准化

class RecResizeImg(object):
    def __init__(self, image_shape=[3, 32, 100]):
        self.image_shape = image_shape

    def __call__(self, data):
        img = data['image']
        norm_img = self.resize_norm_img(img, self.image_shape)
        data['image'] = norm_img
        return data
    
    def resize_norm_img(self, img, image_shape):
        # 缩放图像并对图像进行归一化
        # 缩放图像
        imgC, imgH, imgW = image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        # 如果 w 大于等于100,令 resized_w = 100
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        # 如果 w 小于100,令 resized_w 为 w 向上取整
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        # 对图片进行归一化
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # 新建大小为[3, 32, 100]的空白图像,将缩放后的图像贴到对应位置,其他位置补0
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

返回图像、标签、长度

class KeepKeys(object):
    # 将字典格式的数据转换为列表格式返回
    def __init__(self, keep_keys=['image', 'label', 'length']):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list

汇总上述方法

# 图像预处理方法汇总
def transform(data, mode='train'):
    # 图像解码
    decode_image = DecodeImage()
    # 编码标签
    encode_label = CTCLabelEncode()
    # 缩放图像并标准化
    resize_image = RecResizeImg()
    data = decode_image(data)
    if mode == 'train' or mode == 'val':
        data = encode_label(data)
        keep_keys=['image', 'label', 'length']
    else:
        keep_keys = ['image']
    # 返回图像、标签、长度
    keepkeys = KeepKeys(keep_keys=keep_keys)
    data = resize_image(data)
    data = keepkeys(data)
    return data

定义数据读取类SimpleDataSet,实现数据批量读取和预处理。具体代码如下:

class SimpleDataSet(Dataset):
    def __init__(self, mode, label_file, data_dir, seed=None):
        super(SimpleDataSet, self).__init__()
        self.mode = mode.lower()
        # 标注文件中,使用'\t'作为分隔符区分图片名称与标签
        self.delimiter = '\t'
        # 数据集路径
        self.data_dir = data_dir
        # 随机数种子
        self.seed = seed
        # 获取所有数据,以列表形式返回
        self.data_lines = self.get_image_info_list(label_file)
        # 新建列表存放数据索引
        self.data_idx_order_list = list(range(
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值