PSENet原理与代码解析

目录

Pipeline

Label Generation

Loss

PSE算法


论文 https://arxiv.org/abs/1806.02559

官方源码 GitHub - whai362/PSENet: Official Pytorch implementations of PSENet.

对形状鲁棒的文本检测存在两个挑战:

  • 大多数现有的基于bounding box的检测方法不适用于对任意形状的文本(如弯曲文本)的检测
  • 大多数现有的基于像素分割的检测方法可能无法将挨得非常近的文本区分开

为了解决这两个问题,作者提出了Pogressive Scale Expansion Network(PSENet)。PSENet是一个基于像素分割的检测方法,对于每个文本实例网络输出多个预测结果,每个预测结果对应的ground truth是将完整的文本标注区域按不同的比例缩放得到的多个“核”。然后通过PSE算法将网络输出的结果从小到大扩充得到最终预测结果。即使对于挨得很近的文本实例,将他们缩放后得到的核之间距离也足够大,因次可以很好地将挨得近的文本实例区分开,同时因为是基于像素分割的网络,对于任意形状的文本检测效果很鲁棒。

如上图,图(a)是原始图像。图(b)是基于bounding box检测方法的检测结果,因为检测框无法很好地拟合弯曲文本,导致某一文本实例的检测结果会覆盖其它文本实例。图(c)是传统基于像素分割的检测结果,因为前三行挨得太近导致网络将三个文本实例检测成一个。图(d)是PSENet也正是我们期望的检测结果。

Pipeline

1.  左边的蓝色部分为backbone,论文中是ResNet50,从下到上分别为ResNet50的conv_2x、conv_3x、conv_4x、conv_5x的输出C2、C3、C4、C5。原始输入为640*640*3,C2到C5的shape分别为160*160*256、80*80*512、40*40*1024、20*20*2048。

2.  中间的橘色部分作者借鉴的FPN,以第一个节点为例

       20*20*2048的C5经过1*1*256的Conv、BN、ReLU得到20*20*256的P5

       40*40*1024的C4经过1*1*256的Conv、BN、ReLU得到40*40*256的C4

       P5经过upsample后与C4进行add得到40*40*256的中间结果

       然后经过3*3*256的Conv、BN、ReLU得到40*40*256的P4

  最终P2~P5的shape分别为160*160*256、80*80*256、40*40*256、20*20*256

3.  P5、P4、P3分别upsample成P2的大小,然后concatenate得到160*160*1024的fusion feature即图中的F

4.  然后经过3*3*256的Conv、BN、ReLU、1*1*num_class(论文中num_class=6,包含1个完整的文本标注)得到7*160*160的特征图

5.  最后以stride=4,upsample成输入大小得到最终结果。模型的输入shape为(batch_size,3,640,640),输出shape为(batch_size,7,640,640)

Label Generation

图(a)(b)中的蓝色框p_{n}为文本的原始也是完整的标注信息,对应于最大的segmentation label mask。通过将原始框向内缩放d_{i}个像素得到对应的p_{i}。将p_{i}转化成0/1二值图即得到了对应不同预测kernel的ground truth。

d_{i}的计算公式如下

其中p_{n}是原始标注框,Area和Perimeter分别是面积和周长,r_{i}的计算公式如下

其中m是最终缩放比例,n是kernel个数,论文中取m=0.5,n=6。i的取值范围为[1,n],p_{1}是最小的kenel,p_{6}是最大的kernel。

注意这里官方代码和论文有出入,代码中m=0.4, n=7,且计算r_{i}的时候最后是乘以n而不是n-i。但只是生成的gt顺序相反,这里p_{1}是最小的kernel,p_{6}是最大,代码相反,并不影响训练。

Loss

L_{c}L_{s}分别代表完整的文本实例的loss和缩放的文本实例的loss,\lambda是比例系数,论文中取0.7

其中S_{i,x,y}G_{i,x,y}分别对应网络输出S_{i}和对应的ground truth G_{i}中(x,y)处的值

因为图片中会有许多和文字笔画比较像的物体比如栅栏、格子图案等,因此对L_{c}使用在线困难样本挖掘来更好的区分文本和这些相似的非文本。L_{c}专注于区分文本和非文本。

M是采用了Online Hard Example Mining后得到的mask

def ohem_single(score, gt_text, training_mask):
    pos_num = int(np.sum(gt_text > 0.5)) - int(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
    # 去掉###的gt (training_mask在标签为###的位置值为0,其它位置为1)

    neg_num = int(np.sum(gt_text <= 0.5))
    neg_num = int(min(pos_num * 3, neg_num))

    neg_score = score[gt_text <= 0.5]
    # 模型预测结果图score在gt为背景的部分得分最大的neg_num个像素作为负样本

    neg_score_sorted = np.sort(-neg_score)  # 从小到大
    threshold = -neg_score_sorted[neg_num - 1]

    selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
    selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
    return selected_mask

缩放的文本实例L_{s}的计算公式如下

因为shrunk kernels是被完整的文本实例包围的,因此在计算L_{s}时忽略网络输出中非文本的区域,即这里的W_{x,y},来avoid a certain redundancy。

注意代码中L_{s}计算公式如下,但是最终结果和论文中是一样的

                           L_{s} = \frac{\sum_{i=1}^{n-1}(1-D(S_{i}\cdot W, G_{i}\cdot W))}{n-1}

PSE算法

def pse(kernals, min_area):
    kernal_num = len(kernals) 
    pred = np.zeros(kernals[0].shape, dtype='int32')

    label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4)
    # 找网络输出最小kernel中的连通域。
    # 这里的label_num是连通域个数(包含背景)。label是和输入大小相同的图,若一共5个联通域,则label的背景部分值为0,5个连通域的值分别为1~5

    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    queue = Queue.Queue(maxsize=0)  # 先进先出
    next_queue = Queue.Queue(maxsize=0)
    points = np.array(np.where(label > 0)).transpose((1, 0))  # (10234, 2)所有连通域内的点

    for point_idx in range(points.shape[0]):
        x, y = points[point_idx, 0], points[point_idx, 1]  # 注意这里x是第一维对应的是图像的高
        l = label[x, y]
        queue.put((x, y, l))
        pred[x, y] = l

    dx = [-1, 1, 0, 0]
    dy = [0, 0, -1, 1]
    for kernal_idx in range(kernal_num - 2, -1, -1):
        kernal = kernals[kernal_idx].copy()  # 从上一个kernel即queue里存的值往当前kernel扩充
        while not queue.empty():
            (x, y, l) = queue.get()

            is_edge = True
            for j in range(4):
                tmpx = x + dx[j]
                tmpy = y + dy[j]
                if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]:
                    continue
                if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:  # 当前kernel这个点像素为0或者已经扩充过了
                    continue

                queue.put((tmpx, tmpy, l))
                pred[tmpx, tmpy] = l
                is_edge = False
            if is_edge:
                next_queue.put((x, y, l))

        queue, next_queue = next_queue, queue
        # next_queue里存的是is_edge的像素点,queue是空的。然后下一轮只从is_edge即最外层轮廓开始expand

    return pred

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值