【OCR技术系列之六】文本检测CTPN的代码实现

这几天一直在用Pytorch来复现文本检测领域的CTPN论文,本文章将从数据处理、训练标签生成、神经网络搭建、损失函数设计、训练主过程编写等这几个方面来一步一步复现CTPN。CTPN算法理论可以参考这里

训练数据处理

我们的训练选择天池ICPR2018和MSRA_TD500两个数据集,天池ICPR的数据集为网络图像,都是一些淘宝商家上传到淘宝的一些商品介绍图像,其标签方式参考了ICDAR2015的数据标签格式,即一个文本框用4个坐标来表示,即左上、右上、右下、左下四个坐标,共八个值,记作[x1 y1 x2 y2 x3 y3 x4 y4]

1093303-20181202175424984-936238952.png

天池ICPR2018数据集的风格如下,字体形态格式颜色多变,多嵌套于物体之中,识别难度大:

1093303-20181202181734519-193227309.png

MSRA_TD500使微软收集的一个文本检测和识别的一个数据集,里面的图像多是街景图,背景比较复杂,但文本位置比较明显,一目了然。因为MSRA_TD500的标签格式不一样,最后一个参数表示矩形框的旋转角度。

1093303-20181202175515910-1923330407.png

所以我们第一步就是将这两个数据集的标签格式统一,我的做法是将MSRA数据集格式改为ICDAR格式,方便后面的模型训练。因为MSRA_TD500采取的标签格式是[index difficulty_label x y w h angle],所以我们需要根据这个文本框的旋转角度来求得水平文本框旋转后的4个坐标位置。实现如下:

"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.

MSRA_TD500 format: [index difficulty_label x y w h angle]

ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]

"""


import math
import cv2
import os

# 求旋转后矩形的4个坐标
def get_box_img(x, y, w, h, angle):
    # 矩形框中点(x0,y0)
    x0 = x + w/2
    y0 = y + h/2
    l = math.sqrt(pow(w/2, 2) + pow(h/2, 2))  # 即对角线的一半
    # angle小于0,逆时针转
    if angle < 0:
        a1 = -angle + math.atan(h / float(w))  # 旋转角度-对角线与底线所成的角度
        a2 = -angle - math.atan(h / float(w)) # 旋转角度+对角线与底线所成的角度
        pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
        pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
        pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2))  # x0+左下点旋转后在水平线上的投影, y0-左下点在垂直线上的投影,显然逆时针转时,左下点上一和左移了。
        pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
    else:
        a1 = angle + math.atan(h / float(w))
        a2 = angle - math.atan(h / float(w))
        pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
        pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
        pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
        pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
    return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]


def read_file(path):
    result = []
    for line in open(path):
        info = []
        data = line.split(' ')
        info.append(int(data[2]))
        info.append(int(data[3]))
        info.append(int(data[4]))
        info.append(int(data[5]))
        info.append(float(data[6]))
        info.append(data[0])
        result.append(info)
    return result


if __name__ == '__main__':
    file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
    save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
    save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
    file_list = os.listdir(file_path)
    for f in file_list:
        if '.gt' in f:
            continue
        name = f[0:8]
        txt_path = file_path + name + '.gt'
        im_path = file_path + f
        im = cv2.imread(im_path)
        coordinate = read_file(txt_path)
        # 仿照ICDAR格式,图片名字写做img_xx.jpg,对应的标签文件写做gt_img_xx.txt
        cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
        save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')
        for i in coordinate:
            box = get_box_img(i[0], i[1], i[2], i[3], i[4])
            box = [int(box[i]) for i in range(len(box))]
            box = [str(box[i]) for i in range(len(box))]
            save_gt.write(','.join(box))
            save_gt.write('\n')

经过格式处理后,我们两份数据集算是整理好了。当然我们还需要对整个数据集划分为训练集和测试集,我的文件组织习惯如下:train_im, test_im文件夹装的是训练和测试图像,train_gt和test_gt装的是训练和测试标签。

1093303-20181202175530276-1351005026.png

训练标签生成

因为CTPN的核心思想也是基于Faster RCNN中的region proposal机制的,所以原始数据标签需要转化为
anchor标签。训练数据的标签的生成的代码是最难写,因为从一个完整的文本框标签转化为一个个小尺度文本框标签确实有点难度,而且这个anchor标签的生成方式也与Faster RCNN生成方式略有不同。下面讲一讲我的实现思路:

第一步我们需要将原先每张图的bbox标签转化为每个anchor标签。为了实现该功能,我们先将一张图划分为宽度为16的各个anchor。

  • 首先计算一张图可以分为多少个宽度为16的acnhor(比如一张图的宽度为w,那么水平anchor总数为w/16),再计算出我们的文本框标签中含有几个acnhor,最左和最右的anchor又是哪几个;
  • 计算文本框内anchor的高度和中心是多少:此时我们可以在一个全黑的mask中把文本框label画上去(白色),然后从上往下和从下往上找到第一个白色像素点的位置作为该anchor的上下边界;
  • 最后将每个anchor的位置(水平ID)、anchor中心y坐标、anchor高度存储并返回
def generate_gt_anchor(img, box, anchor_width=16):
    """
    calsulate ground truth fine-scale box
    :param img: input image
    :param box: ground truth box (4 point)
    :param anchor_width:
    :return: tuple (position, h, cy)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # 求解一个bbox下,能分解为多少个16宽度的小anchor,并求出最左和最右的小achor的id
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, downwards
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, upwards
    
    # handle extreme case, the right side anchor may exceed the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1
        
    # combine the left-side and the right-side x_coordinate of a text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]
    
    # 计算每个gt anchor的真实位置,其实就是求解gt anchor的上边界和下边界
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
    # 最后将每个anchor的位置(水平ID)、anchor中心y坐标、anchor高度存储并返回
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center point of anchor box
        result.append((position, cy, h))
    return result

计算anchor上下边界的方法:

# cal the gt anchor box's bottom and top coordinate
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example:[(0, 15), (16, 31), ...]
    :param box: gt box (4 point)
    :return: top and bottom coordinates for y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]
    # 设置图像mask,channel 0为全黑图
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            img[i, j, 0] = 0
            
    top_flag = False
    bottom_flag = False
    # 根据bbox四点画出文本框,channel 0下文本框为白色
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))
    
    
    for k in range(len(position_pair)):
        # 从左到右遍历anchor gt,对每个anchor从上往下扫描像素,遇到白色像素点(255)就停下来,此时像素点坐标y就是该anchor gt的上边界
        # calc top y coordinate
        for y in range(0, height-1):
            # loop each anchor, from left to right
        
  • 4
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
OCR技术是一种能够将图像中的文本内容转化为可编辑文本技术,其中ctpn和crnn是OCR技术中的两个重要组成部分。 ctpn(Connectionist Text Proposal Network)是一种基于深度学习的文本检测算法,其主要任务是检测图像中的文本行和单个字符,并将其转换为一组矩形边界框。这些边界框可以用于后续的文本识别操作。 crnn(Convolutional Recurrent Neural Network)是一种基于深度学习的文本识别算法,其主要任务是根据文本检测阶段生成的文本行或单个字符图像,识别其中的文本内容。crnn算法通常由卷积神经网络(CNN)和循环神经网络(RNN)两个部分组成,其中CNN用于提取图像特征,RNN用于对特征序列进行建模。 以下是一个基于ctpn和crnn的OCR代码实现示例(Python): ```python import cv2 import numpy as np import tensorflow as tf # 加载ctpn模型 ctpn_model = cv2.dnn.readNet('ctpn.pb') # 加载crnn模型 crnn_model = tf.keras.models.load_model('crnn.h5') # 定义字符集 charset = '0123456789abcdefghijklmnopqrstuvwxyz' # 定义字符到索引的映射表 char_to_index = {char: index for index, char in enumerate(charset)} # 定义CTPN参数 ctpn_params = { 'model': 'ctpn', 'scale': 600, 'max_scale': 1200, 'text_proposals': 2000, 'min_size': 16, 'line_min_score': 0.9, 'text_proposal_min_score': 0.7, 'text_proposal_nms_threshold': 0.3, 'min_num_proposals': 2, 'max_num_proposals': 10 } # 定义CRNN参数 crnn_params = { 'model': 'crnn', 'img_w': 100, 'img_h': 32, 'num_classes': len(charset), 'rnn_units': 128, 'rnn_dropout': 0.25, 'rnn_recurrent_dropout': 0.25, 'rnn_activation': 'relu', 'rnn_type': 'lstm', 'rnn_direction': 'bidirectional', 'rnn_merge_mode': 'concat', 'cnn_filters': 32, 'cnn_kernel_size': (3, 3), 'cnn_activation': 'relu', 'cnn_pool_size': (2, 2) } # 定义文本检测函数 def detect_text(image): # 将图像转换为灰度图 gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 缩放图像 scale = ctpn_params['scale'] max_scale = ctpn_params['max_scale'] if np.max(gray) > 1: gray = gray / 255 rows, cols = gray.shape if rows > max_scale: scale = max_scale / rows gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale) rows, cols = gray.shape elif rows < scale: scale = scale / rows gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale) rows, cols = gray.shape # 文本检测 ctpn_model.setInput(cv2.dnn.blobFromImage(gray)) output = ctpn_model.forward() boxes = [] for i in range(output.shape[2]): score = output[0, 0, i, 2] if score > ctpn_params['text_proposal_min_score']: x1 = int(output[0, 0, i, 3] * cols / scale) y1 = int(output[0, 0, i, 4] * rows / scale) x2 = int(output[0, 0, i, 5] * cols / scale) y2 = int(output[0, 0, i, 6] * rows / scale) boxes.append([x1, y1, x2, y2]) # 合并重叠的文本框 boxes = cv2.dnn.NMSBoxes(boxes, output[:, :, :, 2], ctpn_params['text_proposal_min_score'], ctpn_params['text_proposal_nms_threshold']) # 提取文本行图像 lines = [] for i in boxes: i = i[0] x1, y1, x2, y2 = boxes[i] line = gray[y1:y2, x1:x2] lines.append(line) return lines # 定义文本识别函数 def recognize_text(image): # 缩放图像 img_w, img_h = crnn_params['img_w'], crnn_params['img_h'] image = cv2.resize(image, (img_w, img_h)) # 归一化图像 if np.max(image) > 1: image = image / 255 # 转换图像格式 image = image.transpose([1, 0, 2]) image = np.expand_dims(image, axis=0) # 预测文本 y_pred = crnn_model.predict(image) y_pred = np.argmax(y_pred, axis=2)[0] # 将预测结果转换为文本 text = '' for i in y_pred: if i != len(charset) - 1 and (not (len(text) > 0 and text[-1] == charset[i])): text += charset[i] return text # 读取图像 image = cv2.imread('test.png') # 检测文本行 lines = detect_text(image) # 识别文本 texts = [] for line in lines: text = recognize_text(line) texts.append(text) # 输出识别结果 print(texts) ``` 上述代码实现了一个基于ctpn和crnn的OCR系统,其中ctpn用于检测文本行,crnn用于识别文本内容。在使用代码时,需要将ctpn和crnn的模型文件替换为自己训练的模型文件,并根据实际情况调整参数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值