PAN++ 端到端场景文本识别【识别部分精讲】

        本文简要介绍了TPAMI2021录用论文“PAN++: Towards Efficient and Accurate End-to-End Spotting of Arbitrarily-Shaped Text”。该论文展示了一种基于文本内核(即中心区域)的任意形状文本的表示方法,可以较好地区分相邻文本,且对实时的应用场景非常友好。在此基础上,作者建立了一个高效的端到端框架PAN++,可以有效地检测和识别自然场景中任意形状的文本,并且同时做到了高推理速度和高精度。

        这是一个端到端的网络,检测识别特征图权重共享,这里只详细讲解识别部分,对文本识别代码不了解的可以看一下。

论文:

https://arxiv.org/pdf/2105.00405.pdficon-default.png?t=M276https://arxiv.org/pdf/2105.00405.pdf

代码:GitHub - whai362/pan_pp.pytorch: Official implementations of PSENet, PAN and PAN++.icon-default.png?t=M276https://github.com/whai362/pan_pp.pytorch

1.识别部分数据集制作

PAN++的识别char字典共有39个,分别是数字10个,字母26个,以及另外三个 EOS PAD UNK

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'a': 10, 
'b': 11, 'c': 12, 'd': 13, 'e': 14, 'f': 15, 'g': 16, 'h': 17, 'i': 18, 'j': 19, 'k': 20, 
'l': 21, 'm': 22, 'n': 23, 'o': 24, 'p': 25, 'q': 26, 'r': 27, 's': 28, 't': 29, 'u': 30, 
'v': 31, 'w': 32, 'x': 33, 'y': 34, 'z': 35, 'EOS': 36, 'PAD': 37, 'UNK': 38}
<UNK>: 低频词或未在词表中的词
<PAD>: 补全字符
<EOS>: 句子结束标识符

        PAN++规定每一张图片不超过200个文本目标,每一个文本目标不多于32个字符。

        以icdar15为例,首先得到文本的标注信息:

['LEVEL', '###', '###', '###', 'Ion', '###']

        对于每一个张图片文本ground truth制作,先生成(max_word_num + 1, max_word_len)即(201,32)以37(‘PAD’:37)补充的矩阵gt_words:

 gt_words = np.full((self.max_word_num + 1, self.max_word_len),
                           self.char2id['PAD'],
                           dtype=np.int32)

 gt_words.shape = (201, 32)

        同时生成一个word_mask,维度为(max_word_num + 1,)及(201,)的全0矩阵。

word_mask = np.zeros((self.max_word_num + 1, ), dtype=np.int32)

        文本标注写入gt_words之前,先要将所有的大写字母转化为小写字母,从gt_words[1]开始写,第一个要留着。以上面的第一个文本’LEVEL‘为例,该字符有5位,每一位都写上对应的index,第6位写结束符EOS(36),其余的用PAD(37)补齐,碰到不在字典里面的字符,用UNK(38)代替,’LEVEL‘如下显示,写道gt_words的第一个gt_words[1]。然后word_mask[1]写为1.

gt_word = [21 14 31 14 21 36 37 37 37 37 37 37 37 37 37 37 37 37 37 37 37 37 37 37, 37 37 37 37 37 37 37 37]

gt_words[i + 1] = gt_word
word_mask[i + 1] = 1

        对于‘###’的忽略文本,gt_words不做修改, word_mask也不做修改。上述的文本gt_words, word_mask如下:

['LEVEL', '###', '###', '###', 'Ion', '###']

gt_words.shape() = [201, 32]
[[37 37 37 ... 37 37 37], 
 [21 14 31 ... 37 37 37], 
 [37 37 37 ... 37 37 37],
 [18 24 23 36 37... 37 ]
 [37 37 37 ... 37 37 37], 
      ..., 
 [37 37 37 ... 37 37 37], 
 [37 37 37 ... 37 37 37], 
 [37 37 37 ... 37 37 37]]

word_mask.shape() = [201,]
[0 1 0 0 0 1 0 0 0 ... 0]

        由于在网络的训练过程中,需要文本的标注的坐标信息来得到特征图对应的区域,还好用到像素级的实例分割的gt,以及在每一个文本实例区域平行于边界的最小外接矩形。所以还需要gt_instances 和gt_bboxes。

        gt_instances就是每个文本实例的像素级标注;这里先额外解释一下,为什么gt_words、word_masks、gt_bboxes为什么长度为201比图片中最大文本数量还要多1,并且都要从[1]开始写入数据,因为制作gt_instances用到了如下函数:

gt_instance = np.zeros(img.shape[0:2], dtype='uint8')
# 对于每一个文本实例box
cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)

参数:
cv2.drawContours(image, contours, contourIdx, color, thickness=None, lineType=None, hierarchy=None, maxLevel=None, offset=None)

第一个参数是指明在哪幅图像上绘制轮廓;image为三通道才能显示轮廓
第二个参数是轮廓本身,在Python中是一个list;
第三个参数指定绘制轮廓list中的哪条轮廓,如果是-1,则绘制其中的所有轮廓。
第四个color:如果写颜色则为颜色填充,具体数字,则为该数字填充。
其中thickness表明轮廓线的宽度,如果是-1(cv2.FILLED),则为填充模式。

        可以看到,gt_instances背景全为0,第一个文本实例区域为1,第二个为2,以此类推。

for i in range(1, max_instance + 1):
    ind = gt_instance == i
    if np.sum(ind) == 0:  # 这一行跳过背景,正好从上面gt_instance=1开始运行
        continue
    points = np.array(np.where(ind)).transpose((1, 0))
    tl = np.min(points, axis=0)
    br = np.max(points, axis=0) + 1
    gt_bboxes[i] = (tl[0], tl[1], br[0], br[1])

        gt_bboxes这里是一个维度为(201,4)的矩阵,201与上面gt_words、word_masks维度对应,4表示文本区域的坐上和右下的两个点的xy坐标。

2.识别网络的训练过程

2.1 获取特征图f的crop区域、word标注

        经过预处理后的输入图像大小为736×736,网络经过BackBone,在经过PAN++的FPEM、FFM,得到了特征矩阵f,维度为(1,512,184,184)。

        识别的训练就是用这个特征图f来进行的,首先需要通过文本的标注的坐标信息映射到特征图f上,将对应区域crop下来。

x_crops, gt_words = self.rec_head.extract_feature(
                    f,  # 经过前面特征提取后的网络,shape = (1,512,184,184)
                    (imgs.size(2), imgs.size(3)),  # (736, 736)
                    gt_instances * training_masks, # (除去忽略文本后的像素级别的文本区域标注)
                    gt_bboxes,                     # 左上、右下 两个点的坐标 (201,4)
                    gt_words,                      # char2index后的word gt (201,32)
                    word_masks)                    # word 的mask(201,)

        首先将特征图f上采样到预处理后图片的大小:

x = self.conv(f)                    
x = self.relu(self.bn(x))
x = self._upsample(x, output_size) # (1,128,736,736)

        接着对每一个文本实例平行于边界的矩形框部分做mask,区域部分置1:

t, l, b, r = bboxes_[label]  # 左上右下坐标
mask = (instance_[:, t:b, l:r] == label).float()  # 将平行于边界的矩形框内文本像素置1

        然后就是把它crop下来了:

x_crop = x_[:, t:b, l:r] * mask

        输入识别的网络之前,需要将其归一化,PAN++ 归一化的大小为 8×32,对于h > w * 1.5:的文本实例,则需要h,w换位置,即图像旋转90°:

 _, h, w = x_crop.size()
if h > w * 1.5:
     x_crop = x_crop.transpose(1, 2)
x_crop = F.interpolate(x_crop.unsqueeze(0),
                       self.feature_size,  # (8×32)
                       mode='bilinear')

        最后获得到特征图f的crop区域、word标注:

 # x_crop.shape = (1,128,8,32)
 # gt_words.shape = (1,32)

        最后的x_crop 如下所示: 

 2.2 过网络

        论文如是说:Our recognition head is a seq2seq model equipped with multi-head attention . As shown in Fig. 9, it is composed of a starter and a decoder.

        它包含了一个starter和decoder:

        首先经过starter,它包含一个线性的embedding layer(linear transformation )和一个多头注意力层(multi-head attention layer),经过得到文本的全局特征:

def forward(self, x):
    batch_size, feature_dim, H, W = x.size()  # 5,128,8,32 这一部分的batchsize的大小为一张图片的有标注的文本数量
    x_flatten = x.view(batch_size, feature_dim, H * W).permute(0, 2, 1) # (1,256,128)
    st = x.new_full((batch_size, ), self.START_TOKEN, dtype=torch.long) # (36,36,36,36,36)
    emb_st = self.emb(st)  # 获取EOS 的词向量
    holistic_feature, _ = self.att(emb_st, x_flatten, x_flatten)  # 输入注意力机制,返回全局特征
    return holistic_feature 

# 这里解释一下self.emb()
# self.emb = nn.Embedding(self.vocab_size, self.hidden_dim) # (39, 128)
# 学过NLP的就会一眼看出就是一个大小为39的词向量,每个词向量的维度为128,
# 这个39×128的矩阵是参与训练的,也能加载权重

        继续多头注意力机制矩阵:

holistic_feature, _ = self.att(emb_st, x_flatten, x_flatten)
#                 维度分别为(5,128)(5,256,128)(5,256,128)
#   返回的holistic_feature的维度为 (5,128)

        这里得到的emb_st = self.emb(st)就是论文里面说到的 SOS symbol (one-hot) to a 128-dim vector,而holistic_feature则是SOS经过多头注意力得到的;看论文:

结合代码,这个公式就是如下函数:

if self.training:
    return self.decoder(x, holistic_feature, target)

        F{roi}就是特征图f经过crop的,holistic_feature就是刚刚说的,target就是word的gt(5×128),现在就来到的decoder。

    def forward(self, x, holistic_feature, target):
        # print(x.shape, holistic_feature.shape, target.shape)
        batch_size, feature_dim, H, W = x.size()
        x_flatten = x.view(batch_size, feature_dim, H * W).permute(0, 2, 1)

        max_seq_len = target.size(1)
        h = []
        for i in range(self.num_layers):
            h.append((x.new_zeros((x.size(0), self.hidden_dim),
                                  dtype=torch.float32),
                      x.new_zeros((x.size(0), self.hidden_dim),
                                  dtype=torch.float32)))

        out = x.new_zeros((x.size(0), max_seq_len + 1, self.vocab_size),
                          dtype=torch.float32)
        for t in range(max_seq_len + 1):
            if t == 0:
                xt = holistic_feature
            elif t == 1:
                it = x.new_full((batch_size, ),
                                self.START_TOKEN,
                                dtype=torch.long)
                xt = self.emb(it)
            else:
                it = target[:, t - 2]
                xt = self.emb(it)

            for i in range(self.num_layers):
                if i == 0:
                    inp = xt
                else:
                    inp = h[i - 1][0]
                h[i] = self.lstm_u[i](inp, h[i])
            ht = h[-1][0]
            out_t, _ = self.att(ht, x_flatten, x_flatten)
            # print(out_t.shape, _.shape)
            out_t = torch.cat((out_t, ht), dim=1)
            # print(out_t.shape)
            # exit()
            out_t = self.cls(out_t)
            out[:, t, :] = out_t
        return out[:, 1:, :]

        这个代码LSTM用的nn.LSTMCell,有点绕,感兴趣的自己可以看一下,不过最后返回值out_rec的维度为(5,32,39);

        接下来就是训练的损失函数部分:

loss_rec = self.rec_head.loss(out_rec, gt_words,reduce=False)

#  out_rec (5,32,39)
#  gt_words (5,32)

        out_rec(5,32,39)就是一张图里面有五个 文本实例,32个预测字符,每个字符分别为39个字典里面字符的概率,代码采用交叉熵损失:

    def loss(self, input, target, reduce=True):
        EPS = 1e-6
        N, L, D = input.size()
        mask = target != self.char2id['PAD']
        input = input.contiguous().view(-1, D)
        target = target.contiguous().view(-1)
        loss_rec = F.cross_entropy(input, target, reduce=False)
        loss_rec = loss_rec.view(N, L)
        loss_rec = torch.sum(loss_rec * mask.float(),
                             dim=1) / (torch.sum(mask.float(), dim=1) + EPS)
        acc_rec = acc(torch.argmax(input, dim=1).view(N, L),
                      target.view(N, L),
                      mask,
                      reduce=False)
        if reduce:
            loss_rec = torch.mean(loss_rec)  # [valid]
            acc_rec = torch.mean(acc_rec)
        losses = {'loss_rec': loss_rec, 'acc_rec': acc_rec}

        return losses

        至此,训练部分就分析完了。 

3.识别网络的推理过程

        经过检测的网络之后,得到特征图f和det_res,其中检测结果字典有5类数据,如下:

        此图经过检测部分得到了8个文本实例,经过训练同样步骤的在特征图f上面crop,得到x_crops的维度为(8,128,8,32),输入识别网络,得到识别结果:

words, word_scores = self.rec_head.forward(x_crops)

        分别得到了预测的文本以及置信度,代码设置了置信度的域置,超过则保留,否则舍去。

        至此识别部分讲解完成,文章制作不易,看完记得点赞收藏呀!

  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值