CRNN Text Recognition with a TensorFlow Implementation

1. Introduction

    Text recognition converts the characters in a text image into a text string, i.e., into something a computer can actually work with. We previously covered two text detection methods, "CTPN Text Detection with a TensorFlow Implementation" and "EAST Text Detection with a Keras Implementation". Text detection gives us the position of each piece of text in an image; we can then crop out each text fragment and apply an affine transformation to obtain text images like those in Figure 1. At this point, however, the computer still has no idea which characters the image contains, so text recognition is needed to turn the image into plain text. The captcha recognition we see every day is one scenario of text recognition.

Figure 1: Text fragments cropped from natural scene images

    Earlier text recognition models typically used a sliding window and detected the text under each window step by step. This approach copes poorly with varying character shapes and fonts, especially for Chinese text. Other models rely on alignment: every frame of the image is annotated with its text, and an encoder-decoder style architecture performs the recognition. But producing aligned annotations takes a great deal of manual work, which becomes particularly tedious when the text has leading or trailing whitespace. This article therefore introduces CRNN, a model that works quite well for text recognition: it needs no alignment annotations and maps a text image directly to its recognition result, with very high accuracy.

2. Model Description

2.1 Model Architecture

    The CRNN model has three components: convolutional layers, an RNN layer, and a transcription layer, as shown in Figure 2.

Figure 2: CRNN model architecture

    In the convolutional part, every input image is first rescaled to a fixed height and passed through a stack of convolutions. The resulting feature maps are then turned into the input feature sequence for the RNN layer: the feature maps are read column by column from left to right, and for each column the corresponding vectors from all feature maps are concatenated into one vector, which becomes the RNN input for that time step. Because each column of the feature maps corresponds to a rectangular region of the original image, each vector in the feature sequence also corresponds to such a region, and these regions are ordered from left to right, so the feature vectors carry an inherent temporal order. See Figure 3.

Figure 3: Correspondence between the feature sequence from the convolutional layers and regions of the original image
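    To make this map-to-sequence step concrete, below is a minimal numpy sketch (the shapes are made up for illustration) of turning a batch of feature maps into a per-column feature sequence; it mirrors the tf.transpose/tf.reshape pair in the forward() method shown later.

import numpy as np

# hypothetical conv output: [batch, height, width, channels]
feature_maps = np.random.rand(2, 4, 25, 512)

# read the maps column by column: move width in front of height, then
# concatenate each column across all channels -> [batch, width, height*channels]
b, h, w, c = feature_maps.shape
feature_sequence = feature_maps.transpose(0, 2, 1, 3).reshape(b, w, h * c)

print(feature_sequence.shape)  # (2, 25, 2048): 25 time steps, one 2048-dim vector each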

     Next comes the RNN part of the model. As noted above, the vectors in the feature sequence produced by the convolutional layers are temporally related rather than independent, so an RNN is a natural choice. The paper uses a deep bidirectional recurrent network built from LSTM units, as shown in Figure 4. Introducing the RNN brings three benefits: (1) some wide characters span several columns, and the RNN can remember information from earlier steps; moreover, some characters are easier to tell apart when seen together, since their heights can be compared, e.g., 'i' versus 'l'; (2) the RNN can propagate errors back into the CNN, so the RNN and CNN parameters are trained jointly; (3) the RNN naturally handles text sequences of varying length.

Figure 4: LSTM unit and the deep bidirectional RNN

    Suppose the feature sequence from the convolutional layers is $\mathbf{x} = x_1, \dots, x_T$. For the input $x_t$ at each time step, the RNN outputs the class distribution $y_t$ for that step, whose length equals the number of character classes. Write the RNN output sequence as $\mathbf{y} = y_1, \dots, y_T$, where $T$ is the sequence length and $y_t \in \Re^{|\mathcal{L}'|}$ is the character-class probability distribution at step $t$, with $\mathcal{L}' = \mathcal{L} \cup \{blank\}$ the set of all character classes plus the blank symbol. One might wonder whether, given these per-step outputs, we could mark the output sequence with start and end symbols as in machine translation and simply cut out the predicted label sequence. In principle yes, but it would require labeling in advance the receptive field of every time step of every image, a very tedious manual job. The authors therefore do not do this; instead they use a transcription method, the transcription layer of the model.

    In the transcription layer, the authors introduce a mapping $\mathcal{B}$: for a character sequence $\pi \in \mathcal{L}'^T$, $\mathcal{B}$ removes the repeated characters and the blanks, yielding the final character sequence $l$. For example, for the predicted sequence "--hh-e-l-ll-oo--", where "-" denotes the blank, $\mathcal{B}$ outputs "hello". Note that two identical characters separated by a "-" are not merged. The conditional probability of $l$ is then the total probability of all character sequences $\pi$ that $\mathcal{B}$ maps to $l$:

                                        $$p(l \mid \mathbf{y}) = \sum_{\pi : \mathcal{B}(\pi) = l} p(\pi \mid \mathbf{y})$$

where $p(\pi \mid \mathbf{y}) = \prod_{t=1}^{T} y_{\pi_t}^{t}$ is the product of the per-step character probabilities, and $y_{\pi_t}^{t}$ is the probability of character $\pi_t$ at time step $t$. Computed naively, this sum is prohibitively expensive, so the authors borrow the forward-backward algorithm from CTC to evaluate it efficiently.
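    As a minimal sketch of the $\mathcal{B}$ mapping and of why the naive summation is intractable (the charset, the distributions y, and the brute-force enumeration below are purely illustrative, not part of the project code):

import itertools
import numpy as np


def B(pi, blank='-'):
    """Collapse runs of repeated characters, then drop blanks:
    B('--hh-e-l-ll-oo--') == 'hello'."""
    collapsed = [ch for ch, _ in itertools.groupby(pi)]
    return ''.join(ch for ch in collapsed if ch != blank)


def p_l_given_y(l, y, charset):
    """Brute-force p(l|y): sum p(pi|y) over every pi with B(pi) = l.
    charset must include the blank; y[t] is the class distribution at step t.
    This enumerates |charset|^T paths, which is exactly why the CTC
    forward-backward algorithm is needed in practice."""
    T = len(y)
    total = 0.0
    for pi in itertools.product(range(len(charset)), repeat=T):
        if B(''.join(charset[k] for k in pi)) == l:
            total += np.prod([y[t][pi[t]] for t in range(T)])
    return total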

    For how the forward-backward algorithm in CTC works, see my other post "An Introduction to CTC"; it is not expanded on here.

    There are two transcription modes: lexicon-free and lexicon-based.

    Lexicon-free transcription is computed as:

                                        $$l^* \approx \mathcal{B}\left(\arg\max_{\pi} p(\pi \mid \mathbf{y})\right)$$

That is, pick the most probable character at each time step, then apply $\mathcal{B}$ to the resulting sequence to obtain $l$.
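    A minimal sketch of this greedy, lexicon-free decoding (assuming y is a [T, n_classes] array of per-step distributions); this is also what tf.nn.ctc_greedy_decoder does in the implementation below:

import numpy as np


def greedy_decode(y, charset, blank_index):
    """Pick the most likely class at each time step, then apply B:
    merge consecutive repeats and drop blanks."""
    best_path = np.argmax(y, axis=1)  # [T], most likely class per step
    decoded, prev = [], None
    for idx in best_path:
        if idx != blank_index and idx != prev:
            decoded.append(charset[idx])
        prev = idx
    return ''.join(decoded)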

    Lexicon-based transcription builds a lexicon, computes the probability of each character sequence in it, and picks the most probable one as the final transcription:

                                        $$l^* = \arg\max_{l \in \mathcal{D}} p(l \mid \mathbf{y})$$

where $\mathcal{D}$ is the lexicon. The drawback of this computation is that a large lexicon makes it expensive. The authors therefore propose an improvement: since the lexicon-free transcription is usually already close to the ground truth, they first obtain the lexicon-free result $l'$, then use a BK-tree to retrieve from the lexicon all entries whose edit distance to $l'$ is less than $\delta$ (for the notion of edit distance, see the article "Edit Distance"), denoted $\mathcal{N}_\delta(l')$. The probability of each sequence in this neighborhood is then computed, and the most probable one is chosen as the final transcription:

                                        $$l^* = \arg\max_{l \in \mathcal{N}_{\delta}(l')} p(l \mid \mathbf{y})$$
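    A sketch of the lexicon-based search under these definitions; for clarity a linear scan over the lexicon stands in for the BK-tree, and prob_fn is assumed to compute the CTC conditional probability $p(l \mid \mathbf{y})$ for a candidate:

def edit_distance(a, b):
    """Levenshtein distance via single-row dynamic programming."""
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[-1]


def lexicon_decode(l_prime, lexicon, delta, prob_fn):
    """Keep the lexicon entries within edit distance delta of the greedy
    result l' (the neighborhood N_delta(l')), then return the candidate
    with the highest conditional probability."""
    candidates = [w for w in lexicon if edit_distance(l_prime, w) < delta]
    return max(candidates, key=prob_fn) if candidates else l_prime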

2.2 Loss Function

    CRNN uses the negative log-likelihood as its loss function. Let the training set be $\mathcal{X} = \{I_i, l_i\}_i$, where $I_i$ is an input image and $l_i$ its ground-truth character sequence. The loss is:

                                        $$\mathcal{O} = -\sum_{I_i, l_i \in \mathcal{X}} \log p(l_i \mid \mathbf{y}_i)$$
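    In TensorFlow this objective need not be written by hand: tf.nn.ctc_loss returns $-\log p(l_i \mid \mathbf{y}_i)$ per sample, so averaging it yields the loss above. A minimal TF1-style sketch with made-up shapes (the real wiring is in create_loss below):

import tensorflow as tf

logits = tf.random_normal([25, 2, 37])  # [time, batch, n_classes]: 36 chars + 1 blank
labels = tf.SparseTensor(indices=[[0, 0], [1, 0]],  # two samples, one character each
                         values=[3, 8],
                         dense_shape=[2, 1])
seq_len = tf.fill([2], 25)  # every sample uses all 25 time steps
loss = tf.reduce_mean(tf.nn.ctc_loss(labels=labels, inputs=logits,
                                     sequence_length=seq_len, time_major=True))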

3. TensorFlow Implementation

    This article reproduces CRNN in TensorFlow. The project layout is shown in Figure 5; each module is described below.

Figure 5: Project layout

    First is the data directory, which holds the training and test sets: train_images contains the training data and test_images the test data. The data in this article comes from two sources: the ICPR competition dataset and synthetically generated data.

Figure 6: Layout of the data directory

     The dict directory holds the character-set files. There are three options: chinese.txt contains the 3,000 most common Chinese characters, english.txt contains the English letters and some punctuation, and english_chinese.txt is the union of the two. With english_chinese.txt, both Chinese and English text can be recognized; training in this article uses english_chinese.txt by default.

Figure 7: Character-set files

    The fonts directory holds the font files used when generating synthetic data. On Windows they can usually be found under C:\Windows\Fonts; pick whichever fonts you like.

Figure 8: Font files

     images_base holds the background images for the synthetic data, and the models folder holds the trained model files. Next, the Python scripts. charset_generate.py contains the character-set generation function: it extracts the set of characters from the image labels and writes it to charset.txt under the data directory. Its code is as follows:

import tqdm
from crnn import config as crnn_config


def generate_charset(labels_path, charset_path):
    """
    generate the char dictionary from the text labels
    :param labels_path: path of your text labels
    :param charset_path: path to store the char dict
    :return:
    """
    with open(labels_path, 'r', encoding='utf-8') as fr:
        lines = fr.read().split('\n')
    dic = str()
    for label in tqdm.tqdm(lines[:-1]):
        for char in label:
            if char in dic:
                continue
            else:
                dic += char
    with open(charset_path, 'w', encoding='utf-8') as fw:
        fw.write(dic)


if __name__ == '__main__':
    label_path = crnn_config.train_label_path
    char_dict_path = crnn_config.charset_path
    generate_charset(label_path, char_dict_path)

    Next is data_provider.py. On the one hand it crops the text regions out of natural scene images, applies an affine transformation to each fragment, and saves the results into the training and test directories under data for later use; on the other hand it generates synthetic data, which is likewise saved into the training set directory.

import os
import cv2
import math
import random
import shutil
import numpy as np
from tqdm import trange
from collections import Counter
from crnn import charset_generate
from multiprocessing import Process
from crnn import config as crnn_config
from PIL import Image, ImageDraw, ImageFont


class TextCut(object):
    def __init__(self,
                 org_images_path,
                 org_labels_path,
                 cut_train_images_path,
                 cut_train_labels_path,
                 cut_test_images_path,
                 cut_test_labels_path,
                 train_test_ratio=0.8,
                 filter_ratio=1.5,
                 filter_height=25,
                 is_transform=True,
                 angle_range=[-15.0, 15.0],
                 write_mode='w',
                 use_blank=False,
                 num_process=1):
        """
            Crop text regions out of the original ICPR images.
            :param org_images_path: path to the original ICPR images, [str]
            :param org_labels_path: path to the original ICPR labels, [str]
            :param cut_train_images_path: save path for the training crops, [str]
            :param cut_train_labels_path: save path for the labels of the training crops, [str]
            :param cut_test_images_path: save path for the test crops, [str]
            :param cut_test_labels_path: save path for the labels of the test crops, [str]
            :param train_test_ratio: train/test split ratio, [float]
            :param filter_ratio: height/width ratio threshold; images above it are filtered out, default: 1.5, [float]
            :param filter_height: height filter; crops lower than this are filtered out, [int]
            :param is_transform: whether to apply the affine transformation, default: True, [boolean]
            :param angle_range: angle range within which no affine transformation is applied, default: [-15.0, 15.0], [list]
            :param write_mode: data write mode, 'w': write, 'a': append, [str]
            :param use_blank: whether to keep blanks (spaces), [boolean]
            :param num_process: number of parallel processes
            :return:
        """
        self.org_images_path = org_images_path
        self.org_labels_path = org_labels_path
        self.cut_train_images_path = cut_train_images_path
        self.cut_train_labels_path = cut_train_labels_path
        self.cut_test_images_path = cut_test_images_path
        self.cut_test_labels_path = cut_test_labels_path
        self.train_test_ratio = train_test_ratio
        self.filter_ratio = filter_ratio
        self.filter_height = filter_height
        self.is_transform = is_transform
        self.angle_range = angle_range
        assert write_mode in ['w', 'a'], "write mode should be 'w'(write) or 'a'(add)"
        self.write_mode = write_mode
        self.use_blank = use_blank
        self.num_process = num_process
        self.org_labels_list = None
        super().__init__()

    def data_load(self, org_images_list):
        """
        Crop the text out of the ICPR images.
        :param org_images_list: file names of the original images
        :return:
        """
        data_len = len(org_images_list)
        train_test_offset = data_len * self.train_test_ratio
        for data_i in range(len(org_images_list)):
            org_image_path = org_images_list[data_i]
            org_image_name = os.path.basename(org_image_path)[:-4]
            org_label_path = org_image_name + ".txt"
            if org_label_path not in self.org_labels_list:
                continue
            org_image = Image.open(os.path.join(self.org_images_path, org_image_path))
            with open(os.path.join(self.org_labels_path, org_label_path), 'r', encoding='utf-8') as fr:
                org_label = fr.read().split('\n')
            cut_images_list, cut_labels_list = self.cut_text(org_image, org_label,
                                                             self.filter_ratio,
                                                             self.is_transform,
                                                             self.angle_range)
            if data_i < train_test_offset:
                img_save_path = self.cut_train_images_path
                label_save_path = self.cut_train_labels_path
            else:
                img_save_path = self.cut_test_images_path
                label_save_path = self.cut_test_labels_path
            for i in range(len(cut_images_list)):
                cut_img = cut_images_list[i]
                if cut_img.shape[0] >= self.filter_height:
                    cut_img = Image.fromarray(cut_img)
                    cut_img = cut_img.convert('RGB')
                    cut_label = cut_labels_list[i]
                    cut_img_name = org_image_name + '_' + str(i) + '.jpg'
                    cut_img.save(os.path.join(img_save_path, cut_img_name))
                    with open(label_save_path, 'a', encoding='utf-8') as fa:
                        fa.write(cut_img_name + '\t' + cut_label + '\n')

    def data_load_multi_process(self, num_process=None):
        """
        Multi-process text cropping over the ICPR images.
        :param num_process: number of processes, defaults to self.num_process, [int]
        :return:
        """
        if num_process is None:
            num_process = self.num_process
        org_images_list = os.listdir(self.org_images_path)
        self.org_labels_list = os.listdir(self.org_labels_path)
        # clear label.txt at first step
        check_path([self.cut_train_images_path,
                    self.cut_train_labels_path,
                    self.cut_test_images_path,
                    self.cut_test_labels_path])
        if self.write_mode == 'w':
            clear_content([self.cut_train_images_path,
                           self.cut_train_labels_path,
                           self.cut_test_images_path,
                           self.cut_test_labels_path])
        all_data_len = len(org_images_list)
        data_offset = all_data_len // num_process
        processes = list()
        for data_i in trange(0, all_data_len, data_offset):
            if data_i + data_offset >= all_data_len:
                processes.append(Process(target=self.data_load, args=(org_images_list[data_i:],)))
            else:
                processes.append(Process(target=self.data_load, args=(org_images_list[data_i:data_i + data_offset],)))
        for process in processes:
            process.start()
        for process in processes:
            process.join()

    def cut_text(self, image, labels, filter_ratio, is_transform, angle_range):
        """
        Crop the text regions.
        :param image: original image, [array]
        :param labels: text labels, [str]
        :param filter_ratio: height/width ratio threshold; images above it are filtered out, e.g. 1.5, [float]
        :param is_transform: whether to apply the affine transformation, [boolean]
        :param angle_range: angle range within which no affine transformation is applied, e.g. [-15.0, 15.0], [list]
        :return:
        """
        cut_images = list()
        cut_labels = list()
        w, h = image.size
        for label in labels:
            if label == '':
                continue
            label_text = label.split(',')
            text = label_text[-1]
            if not self.use_blank:
                text = text.replace(' ', '')
            if text == '###' or text == '★' or text == '':
                continue
            position = self.reorder_vertexes(
                np.array([[round(float(label_text[i])), round(float(label_text[i + 1]))] for i in range(0, 8, 2)]))
            position = np.reshape(position, 8).tolist()
            left = max(min([position[i] for i in range(0, 8, 2)]), 0)
            right = min(max([position[i] for i in range(0, 8, 2)]), w)
            top = max(min([position[i] for i in range(1, 8, 2)]), 0)
            bottom = min(max([position[i] for i in range(1, 8, 2)]), h)
            if (bottom - top) / (right - left + 1e-3) > filter_ratio:
                continue
            image = np.asarray(image)
            cut_image = image[top:bottom, left:right]
            if is_transform:
                trans_img = self.transform(image, position, angle_range)
                if trans_img is not None:
                    cut_image = trans_img
            cut_images.append(cut_image)
            cut_labels.append(text)
        return cut_images, cut_labels

    def transform(self, image, position, angle_range):
        """
        Affine transformation.
        :param image: original image, [array]
        :param position: position of the text, e.g. [x0,y0,x1,y1,x2,y2], [list]
        :param angle_range: angle range within which no affine transformation is applied, e.g. [-15.0, 15.0], [list]
        :return: the transformed image
        """
        from_points = [position[2:4], position[4:6]]
        width = round(float(self.calc_dis(position[2:4], position[4:6])))
        height = round(float(self.calc_dis(position[2:4], position[0:2])))
        to_points = [[0, 0], [width, 0]]
        from_mat = self.list2col_matrix(from_points)
        to_mat = self.list2col_matrix(to_points)
        tran_m, tran_b = self.get_transform(from_mat, to_mat)
        probe_vec = np.matrix([1.0, 0.0]).transpose()
        probe_vec = tran_m * probe_vec
        scale = np.linalg.norm(probe_vec)
        angle = 180.0 / np.pi * math.atan2(probe_vec[1, 0], probe_vec[0, 0])
        if (angle > angle_range[0]) and (angle < angle_range[1]):
            return None
        else:
            from_center = position[2:4]
            to_center = [0, 0]
            dx = to_center[0] - from_center[0]
            dy = to_center[1] - from_center[1]
            trans_m = cv2.getRotationMatrix2D((from_center[0], from_center[1]), -1 * angle, scale)
            trans_m[0][2] += dx
            trans_m[1][2] += dy
            dst = cv2.warpAffine(image, trans_m, (int(width), int(height)))
            return dst

    def get_transform(self, from_shape, to_shape):
        """
        Compute the transformation matrix A such that y = A * x.
        :param from_shape: shape x before the transformation, as a matrix, [list]
        :param to_shape: shape y after the transformation, as a matrix, [list]
        :return: A
        """
        assert from_shape.shape[0] == to_shape.shape[0] and from_shape.shape[0] % 2 == 0
        sigma_from = 0.0
        sigma_to = 0.0
        cov = np.matrix([[0.0, 0.0], [0.0, 0.0]])
        # compute the mean and cov
        from_shape_points = from_shape.reshape(from_shape.shape[0] // 2, 2)
        to_shape_points = to_shape.reshape(to_shape.shape[0] // 2, 2)
        mean_from = from_shape_points.mean(axis=0)
        mean_to = to_shape_points.mean(axis=0)
        for i in range(from_shape_points.shape[0]):
            temp_dis = np.linalg.norm(from_shape_points[i] - mean_from)
            sigma_from += temp_dis * temp_dis
            temp_dis = np.linalg.norm(to_shape_points[i] - mean_to)
            sigma_to += temp_dis * temp_dis
            cov += (to_shape_points[i].transpose() - mean_to.transpose()) * (from_shape_points[i] - mean_from)
        sigma_from = sigma_from / to_shape_points.shape[0]
        sigma_to = sigma_to / to_shape_points.shape[0]
        cov = cov / to_shape_points.shape[0]
        # compute the affine matrix
        s = np.matrix([[1.0, 0.0], [0.0, 1.0]])
        u, d, vt = np.linalg.svd(cov)
        if np.linalg.det(cov) < 0:
            if d[1] < d[0]:
                s[1, 1] = -1
            else:
                s[0, 0] = -1
        r = u * s * vt
        c = 1.0
        if sigma_from != 0:
            c = 1.0 / sigma_from * np.trace(np.diag(d) * s)
        tran_b = mean_to.transpose() - c * r * mean_from.transpose()
        tran_m = c * r
        return tran_m, tran_b

    def list2col_matrix(self, pts_list):
        """
        Convert a list of points to a column matrix.
        :param pts_list: list of points, e.g. [[x0,y0],[x1,y1],[x2,y2]], [list]
        :return:
        """
        assert len(pts_list) > 0
        col_mat = []
        for i in range(len(pts_list)):
            col_mat.append(pts_list[i][0])
            col_mat.append(pts_list[i][1])
        col_mat = np.matrix(col_mat).transpose()
        return col_mat

    def calc_dis(self, point1, point2):
        """
        Euclidean distance between two points.
        :param point1: 2-D coordinate, e.g. [12.3, 34.1], [list]
        :param point2: 2-D coordinate, e.g. [12.3, 34.1], [list]
        :return: the Euclidean distance between the two points
        """
        return np.sqrt((point2[1] - point1[1]) ** 2 + (point2[0] - point1[0]) ** 2)

    def reorder_vertexes(self, xy_list):
        """
        Reorder the four vertices of a text line counter-clockwise.
        :param xy_list: the four vertex coordinates of the text line, [array]
        :return:
        """
        reorder_xy_list = np.zeros_like(xy_list)

        # pick the first vertex: the one with the smallest x coordinate
        ordered = np.argsort(xy_list, axis=0)
        xmin1_index = ordered[0, 0]
        xmin2_index = ordered[1, 0]
        if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]:
            if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]:
                reorder_xy_list[0] = xy_list[xmin1_index]
                first_v = xmin1_index
            else:
                reorder_xy_list[0] = xy_list[xmin2_index]
                first_v = xmin2_index
        else:
            reorder_xy_list[0] = xy_list[xmin1_index]
            first_v = xmin1_index

        # compute the slopes from the other three vertices to the first one; the vertex with the middle slope becomes the third vertex
        others = list(range(4))
        others.remove(first_v)
        k = np.zeros((len(others),))
        for index, i in zip(others, range(len(others))):
            k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \
                   / (xy_list[index, 0] - xy_list[first_v, 0] + crnn_config.epsilon)
        k_mid = np.argsort(k)[1]
        third_v = others[k_mid]
        reorder_xy_list[2] = xy_list[third_v]

        # compare the remaining two vertices against the line through the first and third vertices:
        # the one above the line becomes the second vertex, the other the fourth
        others.remove(third_v)
        b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
        second_v, fourth_v = 0, 0
        for index, i in zip(others, range(len(others))):
            # delta = y - (k * x + b)
            delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
            if delta_y > 0:
                second_v = index
            else:
                fourth_v = index
        reorder_xy_list[1] = xy_list[second_v]
        reorder_xy_list[3] = xy_list[fourth_v]

        # if the first vertex is the bottom-left one of the quadrilateral, rotate the ordering counter-clockwise by one position
        k13 = k[k_mid]
        k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / (
                xy_list[second_v, 0] - xy_list[fourth_v, 0] + crnn_config.epsilon)
        if k13 < k24:
            tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1]
            for i in range(2, -1, -1):
                reorder_xy_list[i + 1] = reorder_xy_list[i]
            reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y
        return [reorder_xy_list[1], reorder_xy_list[0], reorder_xy_list[3], reorder_xy_list[2]]


class ImageGenerate(object):
    def __init__(self,
                 img_base_path,
                 font_style_path,
                 text_size_limit,
                 font_size,
                 font_color,
                 train_images_path,
                 train_labels_path,
                 test_images_path,
                 test_labels_path,
                 train_test_ratio,
                 num_samples,
                 dictionary_file,
                 margin=20,
                 write_mode='w',
                 use_blank=False,
                 num_process=1):
        """
        Generate synthetic text images.
        :param img_base_path: background image folder, [str]
        :param font_style_path: font folders for Chinese and English fonts, [dict]
        :param text_size_limit: range of the number of characters, e.g. [1,8], [list]
        :param font_size: list of font sizes, e.g. [24,32,36], [list]
        :param font_color: list of font colors, e.g. [[0, 0, 0], [255, 36, 36]], [list]
        :param train_images_path: save path for the training images, [str]
        :param train_labels_path: save path for the training labels, [str]
        :param test_images_path: save path for the test images, [str]
        :param test_labels_path: save path for the test labels, [str]
        :param train_test_ratio: train/test split ratio, [float]
        :param num_samples: total number of samples to generate, [int]
        :param dictionary_file: path to the dictionary file, [str]
        :param margin: margin between the text and the background border
        :param write_mode: data write mode, 'w': write, 'a': append, [str]
        :param use_blank: whether to use blanks (spaces), [boolean]
        :param num_process: number of parallel processes
        """
        self.img_base_path = img_base_path
        self.font_style_path = font_style_path
        self.text_size_limit = text_size_limit
        self.font_size = font_size
        self.font_color = font_color
        self.train_images_path = train_images_path
        self.train_labels_path = train_labels_path
        self.test_images_path = test_images_path
        self.test_labels_path = test_labels_path
        self.train_test_ratio = train_test_ratio
        self.num_samples = num_samples
        self.dictionary_file = dictionary_file
        assert write_mode in ['w', 'a'], "write mode should be 'w'(write) or 'a'(add)"
        self.write_mode = write_mode
        self.use_blank = use_blank
        self.num_process = num_process
        self.margin = margin
        self.base_image_paths = None
        self.list_words = None
        self.used_ch_word = list()
        self.ch_fonts_list = os.listdir(self.font_style_path['ch'])
        self.en_fonts_list = os.listdir(self.font_style_path['en'])
        super().__init__()

    def generate_image(self, start_end):
        """
        Generate the sample images and save them.
        :param start_end: list with the start and end sample IDs, [list]
        :return:
        """
        # check dir and files
        train_test_offset = start_end[0] + (start_end[1] - start_end[0]) * self.train_test_ratio
        for i in range(start_end[0], start_end[1]):
            # get base image by order
            base_img_path = self.base_image_paths[
                (i - start_end[0]) * len(self.base_image_paths) // (start_end[1] - start_end[0])]

            # choice font_color depend on base image
            if os.path.basename(base_img_path).split('_')[1] == '0':
                font_color = random.choice(self.font_color[3:])
            elif os.path.basename(base_img_path).split('_')[1] == '1':
                font_color = random.choice(self.font_color[0:6] + self.font_color[12:])
            elif os.path.basename(base_img_path).split('_')[1] == '2':
                font_color = random.choice(self.font_color[0:12] + self.font_color[15:])
            elif os.path.basename(base_img_path).split('_')[1] == '3':
                font_color = random.choice(self.font_color[0:16])

            # create image draw
            base_img = Image.open(base_img_path)
            base_img_width, base_img_height = base_img.size
            draw = ImageDraw.Draw(base_img)
            while 1:
                try:
                    # randomly choice font size
                    font_size = random.choice(self.font_size)
                    # randomly choice words str
                    words_str_len = random.randint(self.text_size_limit[0], self.text_size_limit[1])
                    only_latin, words_str = self.get_word_str(words_str_len)
                    # randomly choice font style
                    if only_latin:
                        font_style_path = random.choice(self.en_fonts_list)
                        font_style_path = os.path.join(self.font_style_path['en'], font_style_path)
                    else:
                        font_style_path = random.choice(self.ch_fonts_list)
                        font_style_path = os.path.join(self.font_style_path['ch'], font_style_path)

                    font = ImageFont.truetype(font_style_path, font_size)
                    words_str_width, words_str_height = draw.textsize(words_str, font)
                    x0 = random.randint(self.margin, base_img_width - self.margin - words_str_width)
                    y0 = random.randint(self.margin, base_img_height - self.margin - words_str_height)
                    draw.text((x0, y0), words_str, tuple(font_color), font=font)
                    # save Image
                    x_left = x0 - random.randint(0, self.margin)
                    y_top = y0 - random.randint(0, self.margin)
                    x_right = x0 + words_str_width + random.randint(0, self.margin)
                    y_bottom = y0 + words_str_height + random.randint(0, self.margin)
                    base_img = np.asarray(base_img)[:, :, 0:3]
                    image = base_img[y_top:y_bottom, x_left:x_right]
                    image = Image.fromarray(image)
                    if i < train_test_offset:
                        image_dir = self.train_images_path
                        labels_path = self.train_labels_path
                    else:
                        image_dir = self.test_images_path
                        labels_path = self.test_labels_path
                    image_name = 'img_' + str(i).zfill(len(str(self.num_samples))) + '.jpg'
                    image_save_path = os.path.join(image_dir, image_name)
                    image.save(image_save_path)
                    # save labels
                    with open(labels_path, 'a', encoding='utf-8') as fa:
                        fa.write(image_name + '\t' + words_str + '\n')
                    break
                except Exception as e:
                    continue

    def generate_image_multi_process(self, num_process=None):
        """
        Multi-process generation of the sample images.
        :return:
        """
        if num_process is None:
            num_process = self.num_process
        self.base_image_paths = [os.path.join(self.img_base_path, img) for img in
                                 os.listdir(self.img_base_path)]
        words = [Counter(extract_words_i) for extract_words_i in
                 self.extract_words(open(self.dictionary_file, encoding="utf-8").read())]
        self.list_words = [list(words_i.keys()) for words_i in words]
        # check dir and files
        check_path([self.train_images_path,
                    self.train_labels_path,
                    self.test_images_path,
                    self.test_labels_path])
        if self.write_mode == 'w':
            clear_content([self.train_images_path,
                           self.train_labels_path,
                           self.test_images_path,
                           self.test_labels_path])
        data_offset = self.num_samples // num_process
        processes = list()
        for i in trange(0, self.num_samples, data_offset):
            if i + data_offset >= self.num_samples:
                processes.append(Process(target=self.generate_image, args=([i, self.num_samples],)))
            else:
                processes.append(Process(target=self.generate_image, args=([i, i + data_offset],)))
        for process in processes:
            process.start()
        for process in processes:
            process.join()

    def extract_words(self, text):
        """
        Extract the characters.
        :param text: all English and Chinese characters, separated by \n
        :return: words_list, e.g. [['1','2',..],['a','b',...,'A','B',...],[',','!',...],['甲','风',...]]
        """
        words_list = text.split('\n')
        words_list = [i.replace(' ', '') for i in words_list]
        words_list = [[j for j in i] for i in words_list]
        if self.use_blank:
            words_list.append([' '])
        return words_list

    def get_word_str(self, length):
        """
        generate word str randomly
        :param length: length of word str
        :return:
        """
        word_str = ''
        self.used_ch_word = list()
        only_latin = False
        # only latin char
        if random.random() < 0.2:
            for i in range(length):
                if self.use_blank and (i == 0 or i == length - 1):
                    words_list_i = random.choice(self.list_words[:3])
                else:
                    if self.use_blank and random.random() < 0.2:
                        words_list_i = random.choice(self.list_words[:3] + self.list_words[-1])
                    else:
                        words_list_i = random.choice(self.list_words[:3])
                word_str += random.choice(words_list_i)
            only_latin = True
        else:
            for i in range(length):
                if self.use_blank and (i == 0 or i == length - 1):
                    words_list_i = random.choice(self.list_words[:-1])
                else:
                    if self.use_blank and random.random() < 0.2:
                        words_list_i = random.choice(self.list_words)
                    else:
                        words_list_i = random.choice(self.list_words[:-1])
                word_str += random.choice(words_list_i)
        return only_latin, word_str


def check_path(path_list):
    """
    Check whether the paths in the list exist; if not, create the folder or file.
    :param path_list: path list, [list]
    :return:
    """
    for path in path_list:
        if not os.path.exists(path) and '.' not in path[2:]:
            os.mkdir(path)
        elif not os.path.exists(path) and '.' in path[2:]:
            with open(path, 'w', encoding='utf-8') as fw:
                fw.write('')


def clear_content(path_list):
    """
    Clear the contents of the folders and files.
    :param path_list: path list, [list]
    :return:
    """
    for path in path_list:
        if os.path.isdir(path):
            shutil.rmtree(path)
            os.mkdir(path)
        elif os.path.isfile(path):
            os.remove(path)
            with open(path, 'w', encoding='utf-8') as fw:
                fw.write('')


def do_text_cut(write_mode):
    print("{0}".format('text cutting...').center(100, '='))
    print('train_test_ratio={0}\nfilter_ratio={1}\nfilter_height={2}'
          '\nis_transform={3}\nangle_range={4}\nwrite_mode={5}\nuse_blank={6}\nnum_process={7}'.format(
        crnn_config.train_test_ratio,
        crnn_config.filter_ratio,
        crnn_config.filter_height,
        crnn_config.is_transform,
        crnn_config.angle_range,
        write_mode,
        crnn_config.use_blank,
        crnn_config.num_process))
    print('=' * 100)
    text_cut = TextCut(org_images_path=crnn_config.org_images_path,
                       org_labels_path=crnn_config.org_labels_path,
                       cut_train_images_path=crnn_config.cut_train_images_path,
                       cut_train_labels_path=crnn_config.cut_train_labels_path,
                       cut_test_images_path=crnn_config.cut_test_images_path,
                       cut_test_labels_path=crnn_config.cut_test_labels_path,
                       train_test_ratio=crnn_config.train_test_ratio,
                       filter_ratio=crnn_config.filter_ratio,
                       filter_height=crnn_config.filter_height,
                       is_transform=crnn_config.is_transform,
                       angle_range=crnn_config.angle_range,
                       write_mode=write_mode,
                       use_blank=crnn_config.use_blank,
                       num_process=crnn_config.num_process
                       )
    text_cut.data_load_multi_process()


def do_image_generate(write_mode):
    print("{0}".format('image generating...').center(100, '='))
    print('train_test_ratio={0}\nnum_samples={1}\nmargin={2}\nwrite_mode={3}\nuse_blank={4}\nnum_process={5}'
          .format(crnn_config.train_test_ratio, crnn_config.num_samples, crnn_config.margin, write_mode, crnn_config.use_blank,
                  crnn_config.num_process))
    image_generate = ImageGenerate(img_base_path=crnn_config.base_img_dir,
                                   font_style_path=crnn_config.font_style_path,
                                   text_size_limit=crnn_config.text_size_limit,
                                   font_size=crnn_config.font_size,
                                   font_color=crnn_config.font_color,
                                   train_images_path=crnn_config.train_images_path,
                                   train_labels_path=crnn_config.train_label_path,
                                   test_images_path=crnn_config.test_images_path,
                                   test_labels_path=crnn_config.test_label_path,
                                   train_test_ratio=crnn_config.train_test_ratio,
                                   num_samples=crnn_config.num_samples,
                                   dictionary_file=crnn_config.dictionary_file,
                                   margin=crnn_config.margin,
                                   write_mode=write_mode,
                                   use_blank=crnn_config.use_blank,
                                   num_process=crnn_config.num_process)
    image_generate.generate_image_multi_process()


def do_generate_charset(label_path, charset_path):
    """
    Generate the character-set file.
    :param label_path: path to the training labels
    :param charset_path: path to the character-set file
    :return:
    """
    print("{0}".format('charset generating...').center(100, '='))
    print('label_path={0}\ncharset_path={1}'.format(label_path, charset_path))
    print('=' * 100)
    charset_generate.generate_charset(label_path, charset_path)


if __name__ == '__main__':
    do_text_cut(write_mode='w')
    do_image_generate(write_mode='a')
    # do_generate_charset(crnn_config.train_label_path, crnn_config.charset_path)

    data_generator.py holds the data pre-processing functions called during training and testing.

import re
import os
import PIL
import math
import numpy as np
from PIL import Image
from crnn.config import seed
from captcha.image import ImageCaptcha


def get_img_label(label_path, images_path):
    """
    Get the list of image paths and the list of image labels.
    :param label_path: path to the file that maps image paths to labels. [str]
    :param images_path: path to the images. [str]
    :return:
    """
    with open(label_path, 'r', encoding='utf-8') as f:
        lines = f.read()
    lines = lines.split('\n')
    img_path_list = []
    img_label_list = []
    for line in lines[:-1]:
        this_img_path, this_img_label = line.split('\t')
        this_img_path = os.path.join(images_path, this_img_path)
        img_path_list.append(this_img_path)
        img_label_list.append(this_img_label)
    return img_path_list, img_label_list


def get_charsets(dict=None, mode=1, charset_path=None):
    """
    Build the character set.
    :param mode: with mode=1, captchas are generated on the fly for training and the
                 character set is read from the charset file under the dict directory;
                 with mode=2, real scene images are used for training: all text labels
                 in label.txt under the data folder are read and deduplicated to obtain
                 the character set
    :param dict: path to the character-set file
    :param charset_path: path where the character set is stored, only used with mode=2
    :return:
    """
    if mode == 1:
        with open(dict, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        charsets = ''.join(lines)
    else:
        with open(charset_path, 'r', encoding='utf-8') as fr:
            charsets = fr.read()
    charsets = re.sub('[\n\t]', '', charsets)
    charsets = list(set(list(charsets)))
    charsets = sorted(charsets)
    charsets = ''.join(charsets)
    charsets = charsets.encode('utf-8').decode('utf-8')
    return charsets


def gen_random_text(charsets, min_len, max_len):
    """
    Generate random text with length between min_len and max_len.
    :param charsets: character set. [str]
    :param min_len: minimum text length. [int]
    :param max_len: maximum text length. [int]
    :return: the generated index sequence and the text string
    """
    length = seed.random_integers(low=min_len, high=max_len)
    idxs = seed.randint(low=0, high=len(charsets), size=length)
    str = ''.join([charsets[i] for i in idxs])
    return idxs, str


def captcha_gen_img(text, image_shape, fonts):
    """
    Render the given text as a captcha image.
    :param text: input text. [str]
    :param image_shape: image size. [list]
    :param fonts: list of font file paths. [list]
    :return:
    """
    image = ImageCaptcha(height=image_shape[0], width=image_shape[1], fonts=fonts)
    data = image.generate_image(text)
    data = np.reshape(np.frombuffer(data.tobytes(), dtype=np.uint8), image_shape)
    return data


def captcha_batch_gen(batch_size, charsets, min_len, max_len, image_shape, blank_symbol, fonts):
    """
    Generate one batch of captcha data; each batch has three parts: the images,
    the width of each image, and the image labels.
    :param batch_size: batch size
    :param charsets: character set
    :param min_len: minimum text length
    :param max_len: maximum text length
    :param image_shape: size of the generated images
    :param blank_symbol: padding value appended when the text is shorter than max_len
    :param fonts: list of font file paths
    :return:
    """
    batch_labels = []
    batch_images = []
    batch_image_widths = []

    for _ in range(batch_size):
        idxs, text = gen_random_text(charsets, min_len, max_len)
        image = captcha_gen_img(text, image_shape, fonts)
        image = image / 255

        pad_size = max_len - len(idxs)
        if pad_size > 0:
            idxs = np.pad(idxs, pad_width=(0, pad_size), mode='constant', constant_values=blank_symbol)
        batch_image_widths.append(image.shape[1])
        batch_labels.append(idxs)
        batch_images.append(image)

    batch_labels = np.array(batch_labels, dtype=np.int32)
    batch_images = np.array(batch_images, dtype=np.float32)
    batch_image_widths = np.array(batch_image_widths, dtype=np.int32)

    return batch_images, batch_image_widths, batch_labels


def scence_batch_gen(batch_img_list, batch_img_label_list,
                     charsets, image_shape, max_len, blank_symbol):
    """
    Generate one batch of real scene data; each batch has three parts: the images,
    the width of each image, and the image labels.
    :param batch_img_list: list of image paths
    :param batch_img_label_list: list of image labels
    :param charsets: character-set string
    :param image_shape: size of the generated images
    :param max_len: maximum length of the text sequence
    :param blank_symbol: padding value appended when the text is shorter than max_len
    :return:
    """
    batch_labels = []
    batch_image_widths = []
    batch_size = len(batch_img_label_list)
    batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)

    for i, path, label in zip(range(batch_size), batch_img_list, batch_img_label_list):
        # rescale the image
        image = Image.open(path)
        img_size = image.size
        height_ratio = image_shape[0] / img_size[1]
        if int(img_size[0] * height_ratio) > image_shape[1]:
            new_img_size = (image_shape[1], image_shape[0])
            image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
            image = np.array(image, np.float32)
            image = image / 255
            batch_images[i, :, :, :] = image
        else:
            new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
            image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
            image = np.array(image, np.float32)
            image = image / 255
            batch_images[i, :image.shape[0], :image.shape[1], :] = image

        # encode the label
        if len(label) > max_len:
            label = label[:max_len]
        idxs = [charsets.index(i) for i in label]

        # pad the label
        pad_size = max_len - len(idxs)
        if pad_size > 0:
            idxs = np.pad(idxs, pad_width=(0, pad_size), mode='constant', constant_values=blank_symbol)

        batch_image_widths.append(image_shape[1])
        batch_labels.append(idxs)

    batch_labels = np.array(batch_labels, dtype=np.int32)
    batch_image_widths = np.array(batch_image_widths, dtype=np.int32)

    return batch_images, batch_image_widths, batch_labels


def load_images(batch_img_list, image_shape):
    """
    Load a batch of images for prediction, returning the images and their widths.
    :param batch_img_list: list of image paths or list of images, [list]
    :param image_shape: target image size
    :return:
    """
    # the argument is a list of image paths
    if isinstance(batch_img_list[0], str):
        batch_size = len(batch_img_list)
        batch_image_widths = []
        batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)

        for i, path in zip(range(batch_size), batch_img_list):
            # rescale the image
            image = Image.open(path)
            img_size = image.size
            height_ratio = image_shape[0] / img_size[1]
            if int(img_size[0] * height_ratio) > image_shape[1]:
                new_img_size = (image_shape[1], image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :, :, :] = image
            else:
                new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :image.shape[0], :image.shape[1], :] = image
            batch_image_widths.append(image_shape[1])
    # the argument is a list of images
    elif isinstance(batch_img_list[0], PIL.Image.Image):
        batch_size = len(batch_img_list)
        batch_image_widths = []
        batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)

        for i in range(batch_size):
            # rescale the image
            image = batch_img_list[i]
            img_size = image.size
            height_ratio = image_shape[0] / img_size[1]
            if int(img_size[0] * height_ratio) > image_shape[1]:
                new_img_size = (image_shape[1], image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :, :, :] = image
            else:
                new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :image.shape[0], :image.shape[1], :] = image
            batch_image_widths.append(image_shape[1])

    return batch_images, batch_image_widths
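     As a usage sketch (the batch size here is a placeholder), the helpers above combine as follows to build one training batch from the cropped scene images:

from crnn import config as crnn_config
from crnn.data_generator import get_charsets, get_img_label, scence_batch_gen

charsets = get_charsets(mode=2, charset_path=crnn_config.charset_path)
img_list, label_list = get_img_label(crnn_config.train_label_path,
                                     crnn_config.train_images_path)
batch_images, batch_widths, batch_labels = scence_batch_gen(
    img_list[:32], label_list[:32], charsets,
    crnn_config.image_shape, crnn_config.max_len,
    blank_symbol=len(charsets) + 1)
print(batch_images.shape, batch_widths.shape, batch_labels.shape)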

     Finally, the model class file defines the model architecture, the loss function, and the training routine. Its code is as follows:

import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.rnn import BasicLSTMCell
from crnn.data_generator import get_charsets, captcha_batch_gen, scence_batch_gen, get_img_label


class CRNN(object):
    def __init__(self,
                 image_shape,
                 min_len,
                 max_len,
                 lstm_hidden,
                 pool_size,
                 learning_decay_rate,
                 learning_rate,
                 learning_decay_steps,
                 mode,
                 dict,
                 is_training,
                 train_label_path,
                 train_images_path,
                 charset_path):
        self.min_len = min_len
        self.max_len = max_len
        self.lstm_hidden = lstm_hidden
        self.pool_size = pool_size
        self.learning_decay_rate = learning_decay_rate
        self.learning_rate = learning_rate
        self.learning_decay_steps = learning_decay_steps
        self.mode = mode
        self.dict = dict
        self.is_training = is_training
        self.train_label_path = train_label_path
        self.train_images_path = train_images_path
        self.charset_path = charset_path
        self.charsets = get_charsets(self.dict, self.mode, self.charset_path)
        self.image_shape = image_shape
        self.images = tf.placeholder(dtype=tf.float32,
                                     shape=[None, self.image_shape[0], self.image_shape[1], self.image_shape[2]])
        self.image_widths = tf.placeholder(dtype=tf.int32, shape=[None])
        self.labels = tf.placeholder(dtype=tf.int32, shape=[None, self.max_len])
        self.seq_len_inputs = tf.divide(self.image_widths, self.pool_size, name='seq_len_input_op') - 1
        self.logprob = self.forward(self.is_training)
        self.train_op, self.loss_ctc = self.create_train_op(self.logprob)
        self.dense_predicts = self.decode_predict(self.logprob)

    def vgg_net(self, inputs, is_training, scope='vgg'):
        batch_norm_params = {
            'is_training': is_training
        }
        with tf.variable_scope(scope):
            with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params):
                with slim.arg_scope([slim.max_pool2d], padding='SAME'):
                    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
                        net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
                        net = slim.max_pool2d(net, [2, 2], scope='pool1')
                        net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
                        net = slim.max_pool2d(net, [2, 2], scope='pool2')
                        net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
                        net = slim.max_pool2d(net, [2, 2], stride=[2, 1], scope='pool3')
                        net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
                        net = slim.max_pool2d(net, [2, 2], stride=[2, 1], scope='pool4')
                        net = slim.repeat(net, 1, slim.conv2d, 512, [3, 3], scope='conv5')
                        return net

    def forward(self, is_training):
        dropout_keep_prob = 0.7 if is_training else 1.0
        cnn_net = self.vgg_net(self.images, is_training)

        with tf.variable_scope('Reshaping_cnn'):
            shape = cnn_net.get_shape().as_list()  # [batch, height, width, features]
            transposed = tf.transpose(cnn_net, perm=[0, 2, 1, 3],
                                      name='transposed')  # [batch, width, height, features]
            conv_reshaped = tf.reshape(transposed, [-1, shape[2], shape[1] * shape[3]],
                                       name='reshaped')  # [batch, width, height x features]

        list_n_hidden = [self.lstm_hidden, self.lstm_hidden]

        with tf.name_scope('deep_bidirectional_lstm'):
            # Forward direction cells
            fw_cell_list = [BasicLSTMCell(nh, forget_bias=1.0) for nh in list_n_hidden]
            # Backward direction cells
            bw_cell_list = [BasicLSTMCell(nh, forget_bias=1.0) for nh in list_n_hidden]

            lstm_net, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(fw_cell_list,
                                                                            bw_cell_list,
                                                                            conv_reshaped,
                                                                            dtype=tf.float32
                                                                            )
            # Dropout layer
            lstm_net = tf.nn.dropout(lstm_net, keep_prob=dropout_keep_prob)

        with tf.variable_scope('fully_connected'):
            shape = lstm_net.get_shape().as_list()  # [batch, width, 2*n_hidden]
            fc_out = slim.layers.linear(lstm_net, len(self.charsets) + 1)  # [batch x width, n_class]

            lstm_out = tf.reshape(fc_out, [-1, shape[1], len(self.charsets) + 1],
                                  name='lstm_out')  # [batch, width, n_classes]

            # Swap batch and time axis
            logprob = tf.transpose(lstm_out, [1, 0, 2], name='transpose_time_major')  # [width(time), batch, n_classes]

        return logprob

    def create_loss(self, logprob):
        sparse_code_target = self.dense_to_sparse(self.labels, blank_symbol=len(self.charsets) + 1)
        with tf.control_dependencies(
                [tf.less_equal(sparse_code_target.dense_shape[1],
                               tf.reduce_max(tf.cast(self.seq_len_inputs, tf.int64)))]):
            loss_ctc = tf.nn.ctc_loss(labels=sparse_code_target,
                                      inputs=logprob,
                                      sequence_length=tf.cast(self.seq_len_inputs, tf.int32),
                                      preprocess_collapse_repeated=False,
                                      ctc_merge_repeated=True,
                                      ignore_longer_outputs_than_inputs=True,
                                      # returns zero gradient in case it happens -> ema loss = NaN
                                      time_major=True)
            loss_ctc = tf.reduce_mean(loss_ctc)
        return loss_ctc

    def create_train_op(self, logprob):
        loss_ctc = self.create_loss(logprob)
        tf.losses.add_loss(loss_ctc)

        self.global_step = tf.train.get_or_create_global_step()

        learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step,
                                                   self.learning_decay_steps, self.learning_decay_rate,
                                                   staircase=True)

        optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        train_op = slim.learning.create_train_op(total_loss=tf.losses.get_total_loss(), optimizer=optimizer,
                                                 update_ops=update_ops)
        return train_op, loss_ctc

    def decode_predict(self, logprob):
        with tf.name_scope('decode_conversion'):
            sparse_code_pred, log_probability = tf.nn.ctc_greedy_decoder(logprob,
                                                                         sequence_length=tf.cast(
                                                                             self.seq_len_inputs,
                                                                             tf.int32
                                                                         ))
            sparse_code_pred = sparse_code_pred[0]
            dense_predicts = tf.sparse_to_dense(sparse_code_pred.indices,
                                                sparse_code_pred.dense_shape,
                                                sparse_code_pred.values, default_value=-1)

        return dense_predicts

    def dense_to_sparse(self, dense_tensor, blank_symbol):
        """
        Convert the dense labels to a sparse representation.
        :param dense_tensor: the original dense labels
        :param blank_symbol: the padding symbol
        :return:
        """
        indices = tf.where(tf.not_equal(dense_tensor, blank_symbol))
        values = tf.gather_nd(dense_tensor, indices)
        sparse_target = tf.SparseTensor(indices, values, [-1, self.image_shape[1]])
        return sparse_target

    def train(self,
              epoch=100,
              batch_size=32,
              train_images_path=None,
              train_label_path=None,
              restore=False,
              fonts=None,
              logs_path=None,
              models_path=None,
              ):
        # create the output directories
        if not os.path.exists(models_path):
            os.mkdir(models_path)
        if not os.path.exists(logs_path):
            os.mkdir(logs_path)

        # summary
        tf.summary.scalar('loss_ctc', self.loss_ctc)
        merged = tf.summary.merge_all()

        # sess and writer
        sess = tf.Session()
        writer = tf.summary.FileWriter(logs_path, sess.graph)
        saver = tf.train.Saver(max_to_keep=10)
        sess.run(tf.global_variables_initializer())

        # restore model
        last_epoch = 0
        if restore:
            ckpt = tf.train.latest_checkpoint(models_path)
            if ckpt:
                last_epoch = int(ckpt.split('-')[1]) + 1
                saver.restore(sess, ckpt)

        # compute the number of batches
        if self.mode == 1:
            batch_nums = 1000
        else:
            train_img_list, train_label_list = get_img_label(train_label_path, train_images_path)
            batch_nums = int(np.ceil(len(train_img_list) / batch_size))

        if self.mode == 1:
            for i in range(last_epoch, epoch):
                for j in range(batch_nums):
                    batch_images, batch_image_widths, batch_labels = captcha_batch_gen(
                        batch_size,
                        self.charsets,
                        self.min_len,
                        self.max_len,
                        self.image_shape,
                        len(self.charsets) + 1,
                        fonts
                    )
                    _, loss, predict_label = sess.run(
                        [self.train_op, self.loss_ctc, self.dense_predicts],
                        feed_dict={self.images: batch_images,
                                   self.image_widths: batch_image_widths,
                                   self.labels: batch_labels}
                    )
                    if j % 1 == 0:
                        print('epoch:%d/%d, batch:%d/%d, loss:%.4f, truth:%s, predict:%s' % (
                            i, epoch,
                            j, batch_nums,
                            loss,
                            ''.join([self.charsets[k] for k in batch_labels[0] if k != (len(self.charsets) + 1)]),
                            ''.join([self.charsets[v] for v in predict_label[0] if v != -1])
                        ))

                saver.save(sess, save_path=models_path, global_step=i)
                summary = sess.run(merged,
                                   feed_dict={
                                       self.images: batch_images,
                                       self.image_widths: batch_image_widths,
                                       self.labels: batch_labels
                                   })
                writer.add_summary(summary, global_step=i)
        else:
            for i in range(last_epoch, epoch):
                random_index = random.sample(range(len(train_img_list)), len(train_img_list))
                batch_index = np.array_split(np.array(random_index), batch_nums)
                for j in range(batch_nums):
                    this_batch_index = list(batch_index[j])
                    this_train_img_list = [train_img_list[index] for index in this_batch_index]
                    this_train_label_list = [train_label_list[index] for index in this_batch_index]
                    batch_images, batch_image_widths, batch_labels = scence_batch_gen(
                        this_train_img_list,
                        this_train_label_list,
                        self.charsets,
                        self.image_shape,
                        self.max_len,
                        len(self.charsets) + 1
                    )
                    _, loss, predict_label = sess.run(
                        [self.train_op, self.loss_ctc, self.dense_predicts],
                        feed_dict={self.images: batch_images,
                                   self.image_widths: batch_image_widths,
                                   self.labels: batch_labels}
                    )
                    if j % 1 == 0:
                        print('epoch:%d/%d, batch:%d/%d, loss:%.4f, truth:%s, predict:%s' % (
                            i, epoch,
                            j, batch_nums,
                            loss,
                            ''.join([self.charsets[k] for k in batch_labels[0] if k != (len(self.charsets) + 1)]),
                            ''.join([self.charsets[v] for v in predict_label[0] if v != -1])
                        ))

                saver.save(sess, save_path=models_path, global_step=i)
                summary = sess.run(merged,
                                   feed_dict={
                                       self.images: batch_images,
                                       self.image_widths: batch_image_widths,
                                       self.labels: batch_labels
                                   })
                writer.add_summary(summary, global_step=i)

    With the model defined, training can begin. Below is the training script, which simply instantiates the model class.

import os
import tensorflow as tf
from crnn.modules import CRNN
from crnn import config as crnn_config
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


def main(_):
    crnn = CRNN(image_shape=crnn_config.image_shape,
                min_len=crnn_config.min_len,
                max_len=crnn_config.max_len,
                lstm_hidden=crnn_config.lstm_hidden,
                pool_size=crnn_config.pool_size,
                learning_decay_rate=crnn_config.learning_decay_rate,
                learning_rate=crnn_config.learning_rate,
                learning_decay_steps=crnn_config.learning_decay_steps,
                mode=crnn_config.mode,
                dict=crnn_config.dict,
                is_training=True,
                train_label_path=crnn_config.train_label_path,
                train_images_path=crnn_config.train_images_path,
                charset_path=crnn_config.charset_path)
    crnn.train(epoch=crnn_config.epoch,
               batch_size=crnn_config.batch_size,
               train_images_path=crnn_config.train_images_path,
               train_label_path=crnn_config.train_label_path,
               restore=True,
               fonts=crnn_config.fonts,
               logs_path=crnn_config.logs_path,
               models_path=crnn_config.models_path)


if __name__ == '__main__':
    tf.app.run()

    After training finishes, the model checkpoints are stored under the models path; running the predict.py script directly will run prediction on the images under the test-set path.

# -*- coding: utf-8 -*-
"""
    @describe: text recognition with images path or images ndarray list
    @author: xushen
    @date: 2018-12-25
"""
import os
import tensorflow as tf
from crnn.modules import CRNN
from crnn import config as crnn_config
from crnn.data_generator import load_images

crnn_graph = tf.Graph()
with crnn_graph.as_default():
    crnn = CRNN(image_shape=crnn_config.image_shape,
                min_len=crnn_config.min_len,
                max_len=crnn_config.max_len,
                lstm_hidden=crnn_config.lstm_hidden,
                pool_size=crnn_config.pool_size,
                learning_decay_rate=crnn_config.learning_decay_rate,
                learning_rate=crnn_config.learning_rate,
                learning_decay_steps=crnn_config.learning_decay_steps,
                mode=crnn_config.mode,
                dict=crnn_config.dict,
                is_training=False,
                train_label_path=crnn_config.train_label_path,
                train_images_path=crnn_config.train_images_path,
                charset_path=crnn_config.charset_path)

crnn_sess = tf.Session(graph=crnn_graph)
with crnn_sess.as_default():
    with crnn_graph.as_default():
        tf.global_variables_initializer().run()
        crnn_saver = tf.train.Saver(tf.global_variables())
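        # tf.train.get_checkpoint_state reads the 'checkpoint' file under
        # models_path; model_checkpoint_path then points at the latest checkpoint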
        crnn_ckpt = tf.train.get_checkpoint_state(crnn_config.models_path)
        crnn_saver.restore(crnn_sess, crnn_ckpt.model_checkpoint_path)


def predict(images, batch_size=crnn_config.predict_batch_size):
    """
    predict images
    :param images: path of an image or an images dir, or a list of image ndarrays, [list/str]
    :param batch_size: batch size
    :return:
    """
    if isinstance(images, str):
        assert os.path.exists(images), 'path of image or images dir does not exist'
        if os.path.isdir(images):
            test_img_list = os.listdir(images)
            batch_size = len(test_img_list) if len(test_img_list) <= batch_size else batch_size
            test_img_list = [os.path.join(images, i) for i in test_img_list]
            batch_images, batch_image_widths = load_images(
                test_img_list,
                crnn.image_shape
            )
        elif os.path.isfile(images):
            test_img_list = [images]
            batch_size = len(test_img_list) if len(test_img_list) <= batch_size else batch_size
            batch_images, batch_image_widths = load_images(
                test_img_list,
                crnn.image_shape
            )

    elif isinstance(images, list):
        assert len(images) > 0, 'the image list must not be empty'
        batch_size = len(images) if len(images) <= batch_size else batch_size
        batch_images, batch_image_widths = load_images(
            images,
            crnn.image_shape
        )
    else:
        raise TypeError('images must be a path string or a list of image arrays')
    # run inference batch by batch
    predict_label_list = list()
    for i in range(0, len(batch_images), batch_size):
        if i + batch_size >= len(batch_images):
            batch_size = len(batch_images) - i
        predict_label_list.append(crnn_sess.run(crnn.dense_predicts,
                                                feed_dict={crnn.images: batch_images[i:i + batch_size],
                                                           crnn.image_widths: batch_image_widths[i:i + batch_size]}))
    result = list()
    for predict_label in predict_label_list:
        for j in range(len(predict_label)):
            text_i = ''.join([crnn.charsets[v] for v in predict_label[j] if v != -1])
            if text_i.replace(' ', '') != '':
                result.append(text_i)
    return result


if __name__ == '__main__':
    # accepts a local image-directory path, a single local image path, or a list of ndarray images
    predict(crnn_config.predict_images_path, crnn_config.predict_batch_size)
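    Besides a directory path or a single image path, predict also accepts a list of image arrays, as the docstring above notes. A minimal sketch of that usage (the file names are hypothetical, and it assumes load_images handles ndarray input):

import cv2

# hypothetical file names; assumes load_images accepts ndarray input
imgs = [cv2.imread('./crnn/data/predict_images/demo1.jpg'),
        cv2.imread('./crnn/data/predict_images/demo2.jpg')]
print(predict(imgs, batch_size=2))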

     config.py holds the definitions of all the hyperparameters. The main thing to note is the mode parameter: when set to 1, training runs directly on simulated captcha data; when set to 2, a real dataset must be provided for training.

import time
import numpy as np

# data
mode = 2  # mode=1 trains on generated captcha data, mode=2 trains on real scene data
image_shape = [32, 1024, 3]  # image size (height, width, channels)
seed = np.random.RandomState(int(round(time.time())))  # random state used when generating simulated data
min_len = 1  # minimum text length
max_len = 256  # maximum text length
fonts = ['./crnn/fonts/ch_font/STSONG.TTF']  # list of font files used to generate simulated data
train_images_path = './crnn/data/train_images'  # training-set image directory
train_label_path = './crnn/data/train_label.txt'  # training-set label file
test_images_path = './crnn/data/test_images'  # test-set image directory
test_label_path = './crnn/data/test_label.txt'  # test-set label file
dict = './crnn/dict/english.txt'  # lexicon file path
logs_path = './crnn/logs'  # training-log directory
models_path = './crnn/models'  # model checkpoint directory

# data icpr
org_images_path = './crnn/data/origin_images'  # original ICPR image directory
org_labels_path = './crnn/data/txt'  # original ICPR label directory
cut_train_images_path = './crnn/data/train_images'  # where cropped training images are saved
cut_train_labels_path = './crnn/data/train_label.txt'  # where labels of cropped training images are saved
cut_test_images_path = './crnn/data/test_images'  # where cropped test images are saved
cut_test_labels_path = './crnn/data/test_label.txt'  # where labels of cropped test images are saved
train_test_ratio = 0.9  # train/test split ratio
is_transform = True  # whether to apply an affine transformation
angle_range = [-15.0, 15.0]  # tilt-angle range within which no affine transformation is applied
epsilon = 1e-4  # clockwise-transformation parameter for the original image
filter_ratio = 1.3  # height/width ratio for filtering; images above this ratio are filtered out
filter_height = 16  # height filter: cropped images shorter than this are filtered out, [int]

# data generate (with base images)
num_samples = 100  # total number of samples to generate
base_img_dir = './crnn/images_base'  # background-image directory
font_style_path = {'ch': './crnn/fonts/ch_fonts', 'en': './crnn/fonts/en_fonts'}  # font directories
font_size = [12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 40]  # list of font sizes
# font-color list: black 0-3, gray 3-6, blue 6-12, green 12-15, brown 15-16, white 16-17
font_color = [[0, 0, 0], [36, 36, 36], [83, 72, 53], [109, 129, 139], [139, 139, 139], [143, 161, 143],
              [106, 160, 194], [97, 174, 238], [191, 234, 255], [118, 103, 221], [198, 120, 221], [64, 148, 216],
              [147, 178, 139], [76, 136, 107], [62, 144, 135], [209, 125, 72], [255, 255, 255]]
dictionary_file = './crnn/dict/en_ch.txt'  # dictionary file path
text_size_limit = [1, 256]  # character-count range of generated text
margin = 10  # maximum margin between the generated text and the background-image border
use_blank = True  # whether to insert blank (space) characters into the generated text
num_process = 1  # number of processes for parallel data generation, default 1 (single process)

# charset generate
charset_path = './crnn/data/charset.txt'  # character-set file path

# model
lstm_hidden = 256  # number of hidden units in each LSTM layer

# train
pool_size = 2 * 2  # total factor by which the pooling layers shrink the image width
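# for example, with image_shape = [32, 1024, 3] and pool_size = 2 * 2 = 4,
# the CTC input sequence length is 1024 / 4 = 256, which matches max_len above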
batch_size = 32  # batch size
learning_rate = 1e-3  # learning rate
learning_decay_steps = 3000  # decay the learning rate every this many steps
learning_decay_rate = 0.95  # factor the learning rate is multiplied by at each decay
epoch = 100  # number of training epochs

# predict
predict_batch_size = 64
predict_images_path = './crnn/data/predict_images'
predict_label_path = './crnn/data/predict_label.txt'

    After training on the ICPR dataset for roughly 19 epochs, the model had essentially stabilized.

    After training on the captcha dataset, text-recognition accuracy also essentially reached 100%.

4. Summary

To wrap up the article, here is a quick summary of the CRNN model's strengths:

  • CRNN recognizes text end to end
  • It handles text sequences of arbitrary length, with no need for character segmentation or horizontal scale normalization
  • It is not restricted to any predefined lexicon, and performs strongly in both lexicon-free and lexicon-based scene text recognition tasks
  • The model is relatively lightweight