CTPN简述

目录

网络结构

Inference

文本线构造算法

参考


论文 https://arxiv.org/abs/1609.03605

官方代码(caffe) https://github.com/tianzhi0549/CTPN

tf版本代码 https://github.com/eragonruan/text-detection-ctpn/tree/banjin-dev

网络结构

input: (batch_size, height, width, channel),  (1, 896, 608, 3)

backbone: vgg16,conv5_3层输出,stride=16,output_shape = (1, 56, 38, 512)

然后接一个3x3-c512-s1-p1的卷积,输出依旧为(1, 56, 38, 512)。注意这里是tf版本的实现,和官方caffe版本不一样在论文中作者使用caffe中的im2col实现的滑窗操作,将B × W × H × C大小的feature map转换为B × W × H × 9C,本文遵从tf版本。

然后接双向lstm(关于lstm的介绍见LSTM源码解析),输出shape=(1, 56, 38, 256)。然后reshape再接一个全连接层,再reshape回来shape=(1, 56, 38, 512)

然后进入RPN网络,RPN由两个全连接分支组成,分别用来预测类别和位置,shape分别为(1, 56, 38, 10*2)和(1, 56, 38, 10*4),这里的10是anchor的数量。

至此CTPN网格的整个流程就走完了,接下来开始计算loss,在计算loss前还需先计算target。CTPN最终的RPN和faster-rcnn里的rpn很相似,主要在anchor的生成有一些区别。ctpn里采用10个等宽的anchor,widths=[16],heights=[11,16,23,33,48,68,97,139,198,283],因为网络最终的输出特征图相比于输入stride=16,因此anchor的宽统一取16可以在宽度方向完全覆盖原图且anchor之间互相不重叠。

分类loss采用交叉熵损失坐标回归loss采用smooth L1 损失

原文在计算box的回归loss时,只计算anchor中心y和高度的的loss,没有计算anchor中心x坐标和宽度的loss。(注:tf版本中,四个loss都计算了。又注:虽然中心x坐标和宽的loss计算了,但推导阶段并没有用到

CTPN是由faster-rcnn改进而来,具体来说包括

  1. 检测文本是二分类问题就只对应faster-rcnn的stage-one,也就是rpn;
  2. anchor的尺寸和比例由faster-rcnn中的3种scale和3种ratio改为固定宽度,10种高度;
  3. 考虑到文本的时间序列特点,在rpn前加入了bidirectional-lstm;
  4. 训练过程只修正anchor中心x坐标和高度;

Inference

接下来具体讲下模型inference的过程,尤其是将text proposal连接成文本框的文本线构造算法

  1. 图片送入模型得到边框偏差预测bbox_pred_val和类别概率预测cls_prob_val两个feature map,shape分别为(1, 56, 38, 10*4)和(1, 56, 38*10, 2)
  2. 根据bbox_pred_val对anchor进行修正得到proposals(注意只修正anchor中心y坐标和高,anchor的宽直接取16,anchor的中心x坐标根据anchor位置也可以直接计算出来
  3. 对超出图片尺寸的proposals进行clip修正
  4. 过滤掉那些宽或高小于我们设定的rpn_min_size=8的proposal
  5. 对剩下的proposal按cls_prob_val从高到低进行排序,取前pre_num_topN=12000个
  6. 按rpn_nms_thresh=0.7对剩下的proposals进行nms,然后取前post_nms_topN=1000个
  7. 对剩下的proposals进行文本线构造得到最终结果

文本线构造算法

文本线构造算法具体步骤如下:

  1. 保留scores(就是cls_prob_val)大于text_proposals_min_score=0.7的proposal,经过这一步proposal数量由1000缩减到141
  2. 按score排序,按text_proposals_nms_thresh=0.2再次进行nms,经过这一步还剩67个proposals。如下图所示可以看到,到这一步剩下的proposal就是我们最终保留的所有proposal了

  3. 遍历每个proposal,按水平正方向寻找和proposal_{i}水平距离小于max_horizontal_gap=50个像素的其它proposal,因为每个proposal的宽为16,因此最多寻找3个

  4. 从找到的proposal中挑出和proposal_{i}的竖直方向overlap_{v}大于min_v_overlaps=0.7的

  5. 从上一步剩下的proposal中再挑出score最大的proposal_{j},至此proposal_{i}proposal_{j}组成一对,最大score记为score_{i}

  6. 按水平反方向再重复第3~5步,最终每一对的proposal的score记为score_{k}

  7. 建立一个N\times N的graph,N是proposal的数量,在这里为67。如果score_{i}>=score_{k},则这是一个最长连接,graph(i, j)=True;如果score_{i}<score_{k},则这不是一个最长连接,即该连接包含在另一个更长的连接中。

下面是对文本线构造的代码做了一些注释:

"""text_proposal_connector.py"""

import numpy as np

from utils.text_connector.other import clip_boxes
from utils.text_connector.text_proposal_graph_builder import TextProposalGraphBuilder


class TextProposalConnector:
    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
        return graph.sub_graphs_connected()

    @staticmethod
    def fit_y(X, Y, x1, x2):
        # len(X) != 0
        # if X only include one point, the function will get line y=Y[0]
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        # X,Y是长度相同,是多个点的x,y坐标集和
        p = np.poly1d(np.polyfit(X, Y, 1))  # 一阶多项式拟合, 0.02054x + 466.6
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        # tp=text proposal
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)
        # [[59, 36, 25, 28, 34, 44, 43, 37, 32, 35, 46, 51, 55, 57, 64],
        #  [61, 39, 19, 13, 11, 16, 20, 12, 6, 4, 18, 24, 26, 48, 50, 33, 49, 56, 62],
        #  [63, 40, 17, 1, 3, 0, 2, 7, 10, 9, 23, 31, 38, 45, 52, 60],
        #  [66, 53, 29, 15, 21, 14, 5, 8, 22, 27, 30, 41, 42, 47, 54, 58, 65]]
        text_lines = np.zeros((len(tp_groups), 5), np.float32)

        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]

            x0 = np.min(text_line_boxes[:, 0])
            x1 = np.max(text_line_boxes[:, 2])

            offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # 8.0, proposal长度固定为16

            # 取每个proposal的左上角进行直线拟合; 一行proposal最左和最右x坐标向中间移动offset; 根据拟合的直线表达式求移动offset后点的y坐标
            lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

            # the score of a text line is the average score of the scores
            # of all text proposals contained in the text line
            score = scores[list(tp_indices)].sum() / float(len(tp_indices))

            text_lines[index, 0] = x0
            text_lines[index, 1] = min(lt_y, rt_y)
            text_lines[index, 2] = x1
            text_lines[index, 3] = max(lb_y, rb_y)
            text_lines[index, 4] = score

        text_lines = clip_boxes(text_lines, im_size)

        text_recs = np.zeros((len(text_lines), 9), np.float)
        index = 0
        for line in text_lines:
            xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
            text_recs[index, 0] = xmin
            text_recs[index, 1] = ymin
            text_recs[index, 2] = xmax
            text_recs[index, 3] = ymin
            text_recs[index, 4] = xmax
            text_recs[index, 5] = ymax
            text_recs[index, 6] = xmin
            text_recs[index, 7] = ymax
            text_recs[index, 8] = line[4]
            index = index + 1

        return text_recs
"""text_proposal_graph_builder.py"""

import numpy as np

from utils.text_connector.other import Graph
from utils.text_connector.text_connect_cfg import Config as TextLineCfg


class TextProposalGraphBuilder:
    """
        Build Text proposals into a graph.
    """

    def get_successions(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]  # 一个x坐标纵向可能对应多个text_proposal
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results  # 从左往右50个像素就是x方向找3个,找到1个后面的就不用找了
        return results

    def get_precursors(self, index):
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def is_succession_node(self, index, succession_index):
        precursors = self.get_precursors(succession_index)
        if self.scores[index] >= np.max(self.scores[precursors]):
            return True
        return False

    def meet_v_iou(self, index1, index2):
        def overlaps_v(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
            return max(0, y1 - y0 + 1) / min(h1, h2)  # 注意这里overlap的计算并不是除以h1和h2纵向距离的并集

        def size_similarity(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM

    def build_graph(self, text_proposals, scores, im_size):
        self.text_proposals = text_proposals  # (67, 4)
        self.scores = scores  # (67, 1)
        self.im_size = im_size  # (896, 608)
        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1

        boxes_table = [[] for _ in range(self.im_size[1])]  # 长度就是img的宽度,第i个元素就是所有x0==i的text_proposal的索引组成的列表
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[0])].append(index)
        self.boxes_table = boxes_table

        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool)  # (67, 67)

        for index, box in enumerate(text_proposals):
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            succession_index = successions[np.argmax(scores[successions])]
            if self.is_succession_node(index, succession_index):
                # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors)
                # have equal scores.
                graph[index, succession_index] = True
        return Graph(graph)
"""other.py"""

import numpy as np


def threshold(coords, min_, max_):
    return np.maximum(np.minimum(coords, max_), min_)


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.
    """
    boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1)
    boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1)
    return boxes


class Graph:
    def __init__(self, graph):
        self.graph = graph

    def sub_graphs_connected(self):
        sub_graphs = []  # 两层列表,内层每个列表是一个文本行,内层列表的每个元素是text_proposals的索引
        for index in range(self.graph.shape[0]):
            if not self.graph[:, index].any() and self.graph[index, :].any():  # 意思是没有以index为终点的最长连接,有以index为起点的最长连接
                v = index
                sub_graphs.append([v])
                while self.graph[v, :].any():
                    v = np.where(self.graph[v, :])[0][0]
                    # 以v为起点对应的最长连接的终点index作为起点,再寻找当前起点对应的最长连接
                    # np.where(self.graph[v, :])=(array([36]),),本身就只有一个元素。[0][0]是为了把值取出来
                    sub_graphs[-1].append(v)
                    # 因为前面text_proposals没有按x坐标从小到大排序,所以这里一个文本行的元素不是从小到大排列的,
                    # 但是索引对应的proposal在原图上的位置是按x坐标从小到大排列的
        return sub_graphs

参考

https://zhuanlan.zhihu.com/p/34757009

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值