这几天一直在用Pytorch来复现文本检测领域的CTPN论文,本文章将从数据处理、训练标签生成、神经网络搭建、损失函数设计、训练主过程编写等这几个方面来一步一步复现CTPN。CTPN算法理论可以参考这里。
训练数据处理
我们的训练选择天池ICPR2018和MSRA_TD500两个数据集,天池ICPR的数据集为网络图像,都是一些淘宝商家上传到淘宝的一些商品介绍图像,其标签方式参考了ICDAR2015的数据标签格式,即一个文本框用4个坐标来表示,即左上、右上、右下、左下四个坐标,共八个值,记作[x1 y1 x2 y2 x3 y3 x4 y4]
天池ICPR2018数据集的风格如下,字体形态格式颜色多变,多嵌套于物体之中,识别难度大:
MSRA_TD500使微软收集的一个文本检测和识别的一个数据集,里面的图像多是街景图,背景比较复杂,但文本位置比较明显,一目了然。因为MSRA_TD500的标签格式不一样,最后一个参数表示矩形框的旋转角度。
所以我们第一步就是将这两个数据集的标签格式统一,我的做法是将MSRA数据集格式改为ICDAR格式,方便后面的模型训练。因为MSRA_TD500采取的标签格式是[index difficulty_label x y w h angle],所以我们需要根据这个文本框的旋转角度来求得水平文本框旋转后的4个坐标位置。实现如下:
"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.
MSRA_TD500 format: [index difficulty_label x y w h angle]
ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]
"""
import math
import cv2
import os
# 求旋转后矩形的4个坐标
def get_box_img(x, y, w, h, angle):
# 矩形框中点(x0,y0)
x0 = x + w/2
y0 = y + h/2
l = math.sqrt(pow(w/2, 2) + pow(h/2, 2)) # 即对角线的一半
# angle小于0,逆时针转
if angle < 0:
a1 = -angle + math.atan(h / float(w)) # 旋转角度-对角线与底线所成的角度
a2 = -angle - math.atan(h / float(w)) # 旋转角度+对角线与底线所成的角度
pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2)) # x0+左下点旋转后在水平线上的投影, y0-左下点在垂直线上的投影,显然逆时针转时,左下点上一和左移了。
pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
else:
a1 = angle + math.atan(h / float(w))
a2 = angle - math.atan(h / float(w))
pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]
def read_file(path):
result = []
for line in open(path):
info = []
data = line.split(' ')
info.append(int(data[2]))
info.append(int(data[3]))
info.append(int(data[4]))
info.append(int(data[5]))
info.append(float(data[6]))
info.append(data[0])
result.append(info)
return result
if __name__ == '__main__':
file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
file_list = os.listdir(file_path)
for f in file_list:
if '.gt' in f:
continue
name = f[0:8]
txt_path = file_path + name + '.gt'
im_path = file_path + f
im = cv2.imread(im_path)
coordinate = read_file(txt_path)
# 仿照ICDAR格式,图片名字写做img_xx.jpg,对应的标签文件写做gt_img_xx.txt
cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')
for i in coordinate:
box = get_box_img(i[0], i[1], i[2], i[3], i[4])
box = [int(box[i]) for i in range(len(box))]
box = [str(box[i]) for i in range(len(box))]
save_gt.write(','.join(box))
save_gt.write('\n')
经过格式处理后,我们两份数据集算是整理好了。当然我们还需要对整个数据集划分为训练集和测试集,我的文件组织习惯如下:train_im, test_im文件夹装的是训练和测试图像,train_gt和test_gt装的是训练和测试标签。
训练标签生成
因为CTPN的核心思想也是基于Faster RCNN中的region proposal机制的,所以原始数据标签需要转化为
anchor标签。训练数据的标签的生成的代码是最难写,因为从一个完整的文本框标签转化为一个个小尺度文本框标签确实有点难度,而且这个anchor标签的生成方式也与Faster RCNN生成方式略有不同。下面讲一讲我的实现思路:
第一步我们需要将原先每张图的bbox标签转化为每个anchor标签。为了实现该功能,我们先将一张图划分为宽度为16的各个anchor。
- 首先计算一张图可以分为多少个宽度为16的acnhor(比如一张图的宽度为w,那么水平anchor总数为w/16),再计算出我们的文本框标签中含有几个acnhor,最左和最右的anchor又是哪几个;
- 计算文本框内anchor的高度和中心是多少:此时我们可以在一个全黑的mask中把文本框label画上去(白色),然后从上往下和从下往上找到第一个白色像素点的位置作为该anchor的上下边界;
- 最后将每个anchor的位置(水平ID)、anchor中心y坐标、anchor高度存储并返回
def generate_gt_anchor(img, box, anchor_width=16):
"""
calsulate ground truth fine-scale box
:param img: input image
:param box: ground truth box (4 point)
:param anchor_width:
:return: tuple (position, h, cy)
"""
if not isinstance(box[0], float):
box = [float(box[i]) for i in range(len(box))]
result = []
# 求解一个bbox下,能分解为多少个16宽度的小anchor,并求出最左和最右的小achor的id
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width)) # the left side anchor of the text box, downwards
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width)) # the right side anchor of the text box, upwards
# handle extreme case, the right side anchor may exceed the image width
if right_anchor_num * 16 + 15 > img.shape[1]:
right_anchor_num -= 1
# combine the left-side and the right-side x_coordinate of a text anchor into one pair
position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]
# 计算每个gt anchor的真实位置,其实就是求解gt anchor的上边界和下边界
y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
# 最后将每个anchor的位置(水平ID)、anchor中心y坐标、anchor高度存储并返回
for i in range(len(position_pair)):
position = int(position_pair[i][0] / anchor_width) # the index of anchor box
h = y_bottom[i] - y_top[i] + 1 # the height of anchor box
cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0 # the center point of anchor box
result.append((position, cy, h))
return result
计算anchor上下边界的方法:
# cal the gt anchor box's bottom and top coordinate
def cal_y_top_and_bottom(raw_img, position_pair, box):
"""
:param raw_img:
:param position_pair: for example:[(0, 15), (16, 31), ...]
:param box: gt box (4 point)
:return: top and bottom coordinates for y-axis
"""
img = copy.deepcopy(raw_img)
y_top = []
y_bottom = []
height = img.shape[0]
# 设置图像mask,channel 0为全黑图
for i in range(img.shape[0]):
for j in range(img.shape[1]):
img[i, j, 0] = 0
top_flag = False
bottom_flag = False
# 根据bbox四点画出文本框,channel 0下文本框为白色
img = other.draw_box_4pt(img, box, color=(255, 0, 0))
for k in range(len(position_pair)):
# 从左到右遍历anchor gt,对每个anchor从上往下扫描像素,遇到白色像素点(255)就停下来,此时像素点坐标y就是该anchor gt的上边界
# calc top y coordinate
for y in range(0, height-1):
# loop each anchor, from left to right