Paper: https://arxiv.org/abs/1609.03605
Official code (Caffe): https://github.com/tianzhi0549/CTPN
TensorFlow port: https://github.com/eragonruan/text-detection-ctpn/tree/banjin-dev
Network architecture
input: (batch_size, height, width, channel), for example (1, 896, 608, 3)
backbone: VGG16, output of the conv5_3 layer, stride=16, output_shape = (1, 56, 38, 512)
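Next comes a 3x3 convolution with 512 channels, stride 1 and padding 1 (3x3-c512-s1-p1), whose output is still (1, 56, 38, 512). Note that this is the TensorFlow implementation, which differs from the official Caffe version: in the paper the authors implement the sliding window with Caffe's im2col, which turns a B × W × H × C feature map into B × W × H × 9C. This article follows the TensorFlow version.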
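To make the im2col variant concrete, here is a minimal NumPy sketch of a 3x3 sliding window with zero padding and stride 1. The function name `im2col_3x3` and the (B, H, W, C) layout are assumptions chosen to match the shapes used in this article; this is not the Caffe code itself.

```python
import numpy as np


def im2col_3x3(feat):
    """Gather every 3x3 neighbourhood (zero padded, stride 1) into the channel axis.

    feat: (B, H, W, C)  ->  (B, H, W, 9 * C)
    """
    b, h, w, c = feat.shape
    # Pad one pixel of zeros around the spatial dimensions only.
    padded = np.pad(feat, ((0, 0), (1, 1), (1, 1), (0, 0)))
    # One shifted view per position inside the 3x3 window, concatenated on channels.
    patches = [padded[:, dy:dy + h, dx:dx + w, :] for dy in range(3) for dx in range(3)]
    return np.concatenate(patches, axis=-1)


feat = np.random.rand(1, 56, 38, 512).astype(np.float32)
cols = im2col_3x3(feat)
print(cols.shape)  # (1, 56, 38, 4608)
```

The fifth of the nine channel groups is the window centre, so it equals the input feature map; the other eight groups hold the shifted neighbours.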
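Then comes a bidirectional LSTM (for an introduction to LSTM, see the separate post "LSTM源码解析", an LSTM source code walkthrough), with output shape=(1, 56, 38, 256). The output is reshaped, passed through a fully connected layer, and reshaped back to shape=(1, 56, 38, 512).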
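The result then enters the RPN, which consists of two fully connected branches that predict the class and the location respectively, with shapes (1, 56, 38, 10*2) and (1, 56, 38, 10*4). Here 10 is the number of anchors at each feature map position.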
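As an illustration of how the class branch can be turned into one text/non-text probability pair per anchor, the sketch below reshapes a (1, 56, 38, 10*2) score map and applies a softmax. It assumes that the two scores of each anchor sit next to each other on the channel axis; `anchor_text_probs` is a hypothetical helper name, not a function from the repository.

```python
import numpy as np


def anchor_text_probs(cls_pred, num_anchors=10):
    """cls_pred: (B, H, W, num_anchors * 2) raw scores -> (B, H, W * num_anchors, 2) probabilities."""
    b, h, w, _ = cls_pred.shape
    # Assumption: channels are laid out as [anchor0_a, anchor0_b, anchor1_a, anchor1_b, ...].
    logits = cls_pred.reshape(b, h, w * num_anchors, 2)
    logits = logits - logits.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)


cls_pred = np.random.randn(1, 56, 38, 20)
cls_prob = anchor_text_probs(cls_pred)
print(cls_prob.shape)  # (1, 56, 380, 2)
```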
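That completes the forward pass of the CTPN network. The next step is the loss, and before the loss can be computed the targets have to be computed. CTPN's final RPN is very similar to the RPN in Faster R-CNN; the main difference lies in how the anchors are generated. CTPN uses 10 anchors of equal width, widths=[16], heights=[11,16,23,33,48,68,97,139,198,283]. Because the final feature map has stride=16 relative to the input, fixing the anchor width at 16 lets the anchors cover the whole image along the width with no overlap between horizontally adjacent anchors.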
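The anchor layout described above can be sketched as follows. This is a minimal version under stated assumptions: boxes are (x1, y1, x2, y2) with inclusive pixel coordinates, and every anchor is centred on its 16x16 cell. The repository's own anchor generator may round or centre differently, and `generate_anchors` here is only an illustrative name.

```python
import numpy as np

HEIGHTS = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
WIDTH = 16
STRIDE = 16


def generate_anchors(feat_h, feat_w):
    """Return (feat_h * feat_w * 10, 4) anchors as (x1, y1, x2, y2) in image pixels."""
    heights = np.array(HEIGHTS, dtype=np.float32)
    cx = cy = (STRIDE - 1) / 2.0  # centre of the first 16x16 cell
    base = np.stack([
        np.full_like(heights, cx - (WIDTH - 1) / 2.0),
        cy - (heights - 1) / 2.0,
        np.full_like(heights, cx + (WIDTH - 1) / 2.0),
        cy + (heights - 1) / 2.0,
    ], axis=1)  # (10, 4): ten anchors for the cell at the origin
    shift_x, shift_y = np.meshgrid(np.arange(feat_w) * STRIDE, np.arange(feat_h) * STRIDE)
    shifts = np.stack([shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel()], axis=1)
    # Broadcast: every cell offset plus every base anchor.
    return (shifts[:, None, :] + base[None, :, :]).reshape(-1, 4)


anchors = generate_anchors(56, 38)
print(anchors.shape)  # (21280, 4)
```

With a 608-pixel-wide input, the 38 columns of anchors span x = 0 to 607, and neighbouring columns touch without overlapping.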
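The classification loss is cross-entropy, and the coordinate regression loss is smooth L1.
When computing the box regression loss, the paper only computes the loss on the anchor's center y and height, not on the center x coordinate or the width. (Note: the TensorFlow version computes all four losses. Further note: although the losses on center x and width are computed, the predicted center x and width are not used at inference time.)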
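For reference, a minimal NumPy sketch of the two pieces involved: the vertical regression targets as parameterized in the paper, v_c = (c_y - c_y^a) / h^a and v_h = log(h / h^a), and the standard smooth L1 function with its transition point at 1. The function names are illustrative, and the TensorFlow port may scale or weight the loss differently.

```python
import numpy as np


def vertical_targets(gt_cy, gt_h, anchor_cy, anchor_h):
    """Targets for center y and height, relative to the anchor."""
    v_c = (gt_cy - anchor_cy) / anchor_h
    v_h = np.log(gt_h / anchor_h)
    return v_c, v_h


def smooth_l1(pred, target):
    """Element-wise smooth L1: 0.5 * d^2 if |d| < 1, otherwise |d| - 0.5."""
    diff = np.abs(pred - target)
    return np.where(diff < 1.0, 0.5 * diff ** 2, diff - 0.5)


# A ground-truth box twice as tall as the anchor, shifted down by 10 pixels.
v_c, v_h = vertical_targets(gt_cy=110.0, gt_h=32.0, anchor_cy=100.0, anchor_h=16.0)
loss = smooth_l1(np.array([0.5, -2.0]), np.zeros(2))
```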
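CTPN is adapted from Faster R-CNN. Specifically:
- text detection is a binary classification problem, so only stage one of Faster R-CNN, the RPN, is needed;
- the anchor sizes and ratios change from the 3 scales and 3 ratios of Faster R-CNN to a fixed width with 10 heights;
- to exploit the sequential nature of text, a bidirectional LSTM is added before the RPN;
- training only regresses the anchor's center y coordinate and height.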
Inference
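This section walks through model inference in detail, in particular the text line construction algorithm that connects text proposals into text boxes.
- The image is fed to the model, which produces two feature maps: the box offset predictions bbox_pred_val and the class probability predictions cls_prob_val, with shapes (1, 56, 38, 10*4) and (1, 56, 38*10, 2)
- The anchors are refined with bbox_pred_val to obtain proposals (note that only the anchor's center y coordinate and height are refined; the anchor width is simply 16, and the anchor's center x coordinate can be computed directly from the anchor position)
- Proposals that extend beyond the image are clipped
- Proposals whose width or height is smaller than the configured rpn_min_size=8 are filtered out
- The remaining proposals are sorted by cls_prob_val in descending order and the first pre_nms_topN=12000 are kept
- NMS with rpn_nms_thresh=0.7 is run on the remaining proposals, and the first post_nms_topN=1000 are kept
- Text line construction is run on the remaining proposals to produce the final result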
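The decode, clip, size filter and top-N steps of the list above can be sketched as below; NMS is left out here. The helper names `decode_vertical` and `filter_proposals` are hypothetical, and the coordinate convention (inclusive pixels, zero deltas leave the anchor unchanged) is an assumption of this sketch rather than the repository's exact arithmetic.

```python
import numpy as np


def decode_vertical(anchors, deltas):
    """Apply only (dy, dh) from deltas = (dx, dy, dw, dh); x1 and x2 stay those of the anchor."""
    heights = anchors[:, 3] - anchors[:, 1] + 1.0
    ctr_y = 0.5 * (anchors[:, 1] + anchors[:, 3])
    new_ctr_y = deltas[:, 1] * heights + ctr_y
    new_h = np.exp(deltas[:, 3]) * heights
    proposals = anchors.astype(np.float64)  # astype returns a copy, the input is not modified
    proposals[:, 1] = new_ctr_y - 0.5 * (new_h - 1.0)
    proposals[:, 3] = new_ctr_y + 0.5 * (new_h - 1.0)
    return proposals


def filter_proposals(proposals, scores, im_h, im_w, min_size=8, pre_nms_top_n=12000):
    """Clip to the image, drop boxes smaller than min_size, keep the top-scoring ones."""
    proposals = proposals.copy()
    proposals[:, 0::2] = np.clip(proposals[:, 0::2], 0, im_w - 1)
    proposals[:, 1::2] = np.clip(proposals[:, 1::2], 0, im_h - 1)
    ws = proposals[:, 2] - proposals[:, 0] + 1
    hs = proposals[:, 3] - proposals[:, 1] + 1
    keep = np.where((ws >= min_size) & (hs >= min_size))[0]
    order = keep[np.argsort(-scores[keep])][:pre_nms_top_n]
    return proposals[order], scores[order]


anchors = np.array([[0, 0, 15, 15], [16, 0, 31, 15], [592, 880, 607, 895]], dtype=np.float64)
deltas = np.array([[0.3, 0.0, 0.2, 0.0],           # dx and dw are ignored
                   [0.0, 0.5, 0.0, np.log(2.0)],   # move down and double the height
                   [0.0, 0.0, 0.0, np.log(0.25)]])  # shrink to height 4, filtered out later
scores = np.array([0.2, 0.9, 0.8])

decoded = decode_vertical(anchors, deltas)
kept_boxes, kept_scores = filter_proposals(decoded, scores, im_h=896, im_w=608)
```

In this toy input the third box ends up only 4 pixels tall and is removed by the size filter; the remaining two come back ordered by score.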
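Text line construction algorithm
The text line construction algorithm proceeds as follows:
- Keep the proposals whose scores (that is, cls_prob_val) are greater than text_proposals_min_score=0.7; in this example the number of proposals drops from 1000 to 141
- Sort by score and run NMS again with text_proposals_nms_thresh=0.2; after this step 67 proposals remain. As the figure below shows, the proposals left after this step are all the proposals that are finally kept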
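- For each proposal box_i, search in the positive horizontal direction for other proposals whose horizontal distance from box_i is less than max_horizontal_gap=50 pixels. Because every proposal is 16 pixels wide, at most 3 candidate positions are searched
- From the proposals found, keep those whose vertical overlap with box_i is at least min_v_overlaps=0.7
- From the proposals left by the previous step, pick the one with the highest score, box_j. box_i and box_j now form a pair, and this highest score is recorded as score_j
- Repeat steps 3 to 5 in the negative horizontal direction, starting from box_j; the highest score found among its candidates is recorded as score_k
- Build an N×N graph, where N is the number of proposals (67 here). If score_i >= score_k, this is a longest connection and graph(i, j)=True; if score_i < score_k, this is not a longest connection, that is, the connection is contained in another, longer connection.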
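The first two steps of the list above (score threshold, then a second NMS) can be sketched with a plain greedy NMS in NumPy. This is a minimal version written for illustration; the function names are hypothetical and the repository may use a different (for example compiled) NMS implementation.

```python
import numpy as np


def nms(boxes, scores, thresh):
    """Greedy NMS on (x1, y1, x2, y2) boxes; returns kept indices in descending score order."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= thresh]  # drop everything that overlaps box i too much
    return np.array(keep, dtype=np.int64)


def select_text_proposals(proposals, scores, min_score=0.7, nms_thresh=0.2):
    """Score threshold followed by NMS, mirroring the first two steps above."""
    keep = np.where(scores > min_score)[0]
    proposals, scores = proposals[keep], scores[keep]
    keep = nms(proposals, scores, nms_thresh)
    return proposals[keep], scores[keep]


boxes = np.array([[0, 0, 15, 31], [4, 0, 19, 31], [16, 0, 31, 31], [32, 0, 47, 31]], dtype=np.float64)
scores = np.array([0.95, 0.9, 0.8, 0.5])
sel_boxes, sel_scores = select_text_proposals(boxes, scores)
```

Here the last box falls below the score threshold, and the second box is suppressed because its IoU with the first one is 0.6, above the 0.2 threshold. The third box only touches the first one, so it survives.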
Below is the text line construction code, with some added comments:
"""text_proposal_connector.py"""
import numpy as np
from utils.text_connector.other import clip_boxes
from utils.text_connector.text_proposal_graph_builder import TextProposalGraphBuilder


class TextProposalConnector:
    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
        return graph.sub_graphs_connected()

    @staticmethod
    def fit_y(X, Y, x1, x2):
        # len(X) != 0
        # if X only contains one distinct x value, the function returns the line y=Y[0]
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        # X and Y have the same length and hold the x and y coordinates of several points
        p = np.poly1d(np.polyfit(X, Y, 1))  # first-order polynomial fit, e.g. 0.02054x + 466.6
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        # tp = text proposal
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)
        # [[59, 36, 25, 28, 34, 44, 43, 37, 32, 35, 46, 51, 55, 57, 64],
        #  [61, 39, 19, 13, 11, 16, 20, 12, 6, 4, 18, 24, 26, 48, 50, 33, 49, 56, 62],
        #  [63, 40, 17, 1, 3, 0, 2, 7, 10, 9, 23, 31, 38, 45, 52, 60],
        #  [66, 53, 29, 15, 21, 14, 5, 8, 22, 27, 30, 41, 42, 47, 54, 58, 65]]
        text_lines = np.zeros((len(tp_groups), 5), np.float32)

        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]

            x0 = np.min(text_line_boxes[:, 0])
            x1 = np.max(text_line_boxes[:, 2])

            offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # 8.0, the proposal width is fixed at 16

            # Fit a line through the top-left corners of the proposals; move the leftmost and
            # rightmost x of the line inwards by offset; evaluate the fitted line at those x values
            lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

            # the score of a text line is the average score of the scores
            # of all text proposals contained in the text line
            score = scores[list(tp_indices)].sum() / float(len(tp_indices))

            text_lines[index, 0] = x0
            text_lines[index, 1] = min(lt_y, rt_y)
            text_lines[index, 2] = x1
            text_lines[index, 3] = max(lb_y, rb_y)
            text_lines[index, 4] = score

        text_lines = clip_boxes(text_lines, im_size)

        # Each row: the four corners (clockwise from top-left) followed by the score
        text_recs = np.zeros((len(text_lines), 9), np.float64)
        for index, line in enumerate(text_lines):
            xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
            text_recs[index, 0] = xmin
            text_recs[index, 1] = ymin
            text_recs[index, 2] = xmax
            text_recs[index, 3] = ymin
            text_recs[index, 4] = xmax
            text_recs[index, 5] = ymax
            text_recs[index, 6] = xmin
            text_recs[index, 7] = ymax
            text_recs[index, 8] = line[4]

        return text_recs
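To see what the line fitting in get_text_lines does on concrete numbers, here is a small standalone sketch. It reimplements fit_y as a free function and uses three made-up proposals whose top and bottom edges rise by 2 pixels per proposal; the coordinates are illustrative only.

```python
import numpy as np


def fit_y(xs, ys, x_left, x_right):
    """Fit y = a*x + b through the points and evaluate it at x_left and x_right."""
    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    if np.all(xs == xs[0]):  # a single x value: fall back to a horizontal line
        return ys[0], ys[0]
    line = np.poly1d(np.polyfit(xs, ys, 1))
    return line(x_left), line(x_right)


# Three neighbouring proposals (x1, y1, x2, y2) belonging to one slightly slanted text line.
boxes = np.array([[0, 10, 15, 40], [16, 12, 31, 42], [32, 14, 47, 44]], dtype=np.float64)
x0, x1 = boxes[:, 0].min(), boxes[:, 2].max()
offset = (boxes[0, 2] - boxes[0, 0]) * 0.5

lt_y, rt_y = fit_y(boxes[:, 0], boxes[:, 1], x0 + offset, x1 - offset)  # top edge
lb_y, rb_y = fit_y(boxes[:, 0], boxes[:, 3], x0 + offset, x1 - offset)  # bottom edge
text_line = (x0, min(lt_y, rt_y), x1, max(lb_y, rb_y))
```

The resulting axis-aligned box takes the highest point of the fitted top edge and the lowest point of the fitted bottom edge, so it encloses the slanted line.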
"""text_proposal_graph_builder.py"""
import numpy as np
from utils.text_connector.other import Graph
from utils.text_connector.text_connect_cfg import Config as TextLineCfg


class TextProposalGraphBuilder:
    """
    Build Text proposals into a graph.
    """

    def get_successions(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]  # several text proposals, stacked vertically, can share one x coordinate
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                # 50 pixels to the right covers 3 positions in x; once one position yields a match, stop searching
                return results
        return results

    def get_precursors(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def is_succession_node(self, index, succession_index):
        precursors = self.get_precursors(succession_index)
        if self.scores[index] >= np.max(self.scores[precursors]):
            return True
        return False

    def meet_v_iou(self, index1, index2):
        def overlaps_v(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
            # Note: the overlap is divided by the smaller height, not by the union of the two vertical extents
            return max(0, y1 - y0 + 1) / min(h1, h2)

        def size_similarity(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
            size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM

    def build_graph(self, text_proposals, scores, im_size):
        self.text_proposals = text_proposals  # (67, 4)
        self.scores = scores  # (67, 1)
        self.im_size = im_size  # (896, 608)
        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1

        # One list per image column: element i holds the indices of all text proposals with x0 == i
        boxes_table = [[] for _ in range(self.im_size[1])]
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[0])].append(index)
        self.boxes_table = boxes_table

        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)  # (67, 67)

        for index, box in enumerate(text_proposals):
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            succession_index = successions[np.argmax(scores[successions])]
            if self.is_succession_node(index, succession_index):
                # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors)
                # have equal scores.
                graph[index, succession_index] = True
        return Graph(graph)
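The remark in meet_v_iou about the denominator is easier to see with numbers. The standalone sketch below compares the overlap used in the code (divided by the smaller height) with an IoU-style ratio (divided by the union) for two made-up boxes. The function names are illustrative, and the 0.7 threshold is the min_v_overlaps value quoted earlier.

```python
def vertical_overlap(box_a, box_b):
    """Vertical intersection divided by the smaller height; boxes are (x1, y1, x2, y2)."""
    h_a = box_a[3] - box_a[1] + 1
    h_b = box_b[3] - box_b[1] + 1
    inter = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1)
    return inter / min(h_a, h_b)


def vertical_iou(box_a, box_b):
    """Same intersection, but divided by the union of the two vertical extents."""
    h_a = box_a[3] - box_a[1] + 1
    h_b = box_b[3] - box_b[1] + 1
    inter = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1)
    return inter / (h_a + h_b - inter)


def size_similarity(box_a, box_b):
    """Ratio of the smaller height to the larger height."""
    h_a = box_a[3] - box_a[1] + 1
    h_b = box_b[3] - box_b[1] + 1
    return min(h_a, h_b) / max(h_a, h_b)


a = (0, 10, 15, 41)   # height 32
b = (16, 20, 31, 43)  # height 24
overlap = vertical_overlap(a, b)  # 22 / 24
iou = vertical_iou(a, b)          # 22 / 34
similarity = size_similarity(a, b)  # 24 / 32
```

For this pair the min-normalized overlap is about 0.92 and passes a 0.7 threshold, while the IoU-style ratio is about 0.65 and would not. Dividing by the smaller height is therefore the more permissive test when neighbouring proposals differ in height.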
"""other.py"""
import numpy as np


def threshold(coords, min_, max_):
    return np.maximum(np.minimum(coords, max_), min_)


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.
    """
    boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1)
    boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1)
    return boxes


class Graph:
    def __init__(self, graph):
        self.graph = graph

    def sub_graphs_connected(self):
        # List of lists: each inner list is one text line, and its elements are indices into text_proposals
        sub_graphs = []
        for index in range(self.graph.shape[0]):
            # No longest connection ends at index, and a longest connection starts at index
            if not self.graph[:, index].any() and self.graph[index, :].any():
                v = index
                sub_graphs.append([v])
                while self.graph[v, :].any():
                    # Take the end of the longest connection that starts at v as the new start,
                    # then look for the longest connection of that new start.
                    # np.where(self.graph[v, :]) = (array([36]),) holds a single element; [0][0] extracts the value
                    v = np.where(self.graph[v, :])[0][0]
                    sub_graphs[-1].append(v)
        # text_proposals were not sorted by x beforehand, so the indices inside a text line are not in
        # increasing order, but the proposals they refer to are ordered by increasing x in the image
        return sub_graphs
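A toy run makes the traversal in sub_graphs_connected easier to follow. The sketch below applies the same logic to a hand-built 5x5 adjacency matrix; the free function `connected_chains` and the edges are made up for illustration.

```python
import numpy as np


def connected_chains(graph):
    """Follow graph[i, j] == True edges from every node that has no incoming edge."""
    chains = []
    for start in range(graph.shape[0]):
        if not graph[:, start].any() and graph[start, :].any():
            chain = [start]
            v = start
            while graph[v, :].any():
                v = int(np.where(graph[v, :])[0][0])  # first successor of v
                chain.append(v)
            chains.append(chain)
    return chains


graph = np.zeros((5, 5), dtype=bool)
graph[3, 0] = True  # proposal 3 -> proposal 0
graph[0, 4] = True  # proposal 0 -> proposal 4
graph[2, 1] = True  # proposal 2 -> proposal 1

chains = connected_chains(graph)
print(chains)  # [[2, 1], [3, 0, 4]]
```

Because the graph builder only adds edges that point in the positive x direction, the chains cannot contain cycles and the while loop terminates.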