【文本识别系列】Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition

论文原文:【链接】
解读代码:【链接】

在这里插入图片描述
个人体会: 这个工作有点像是做了一个伪字符级别(定长处理)的结果,在处理上,通过垂直方向的缩小,在水平方向上进行字符级的注意力机制和序列预测。全文的主要工作在于做了一个注意力机制的2D-attention,并且在decoder中进行了运用。实验结果上看可以在不规则文本的数据集上取得良好的性能。

encoder部分

在这里插入图片描述
这里不同于常见的encoder论文,这里只取了最后一个特征信息,称之为Holistic Feature (1*1*c),这个做法很独特。

@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
    def __init__(self):
        super().__init__()
        # LSTM Encoder
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def init_weights(self):
        # initialize weight and bias

    def forward(self, feat, img_metas=None):
        h_feat = feat.size(2)
        #step 1 :压缩了水平方向的向量,维度从(b,c,w,h) -> (b,c,w,h) -> (b,c,w) -> (b,w,c) 
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C
		
		#step 2 : 将1D序列放入循环序列网络中,提取上下文关系依赖
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C
		
		#step 3 : 提取序列中的最后一个
        valid_hf = holistic_feat[:, -1, :]  # bsz * C

        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat

decoder部分

在这里插入图片描述
作者把这个Holistic Feature的片段作为开始序列,(这里我理解为start字符的开始位置为特征信息所做的解码,这里用编码的部分作为start去模拟这一方式,个人理解),之后使用ground truth作为后序输入信息(图中y所示),代码中将两个部分进行了cat拼接操作。

   def forward_train(self, feat, out_enc, targets_dict, img_metas):
        #step 1 : 将target部分词嵌入,并输入到对应的编码器中
        targets = targets_dict['padded_targets'].to(feat.device)
        tgt_embedding = self.embedding(targets)
        # bsz * seq_len * emb_dim
		
		#step 2 : 将encoder输出的部分调整成相应维度并进行拼接
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
        # bsz * (seq_len + 1) * C

		#step 3 : 送入2Dattention机制中
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        seq_len = self.max_seq_len

        bsz = feat.size(0)

		#step 1 : 用开始字符去替代target,并进行同样处理
        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)

		#step 2 : 将encoder部分作为第一个序列,进行输入
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        #step 3 : 通过循环的方式不断输入到2d attention中,进行预测
        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs

2D-attention部分:


这里的2D-attention其实就是一个注意力机制,Hidden_state作为一个Query,Feature Map作为一个Key和Value存在。(这里的hidden states 第一个字符为encoder的输出Holistic Feature,之后的为解码器的LSTM的输出)注意力公式可以列成

step 1: 加性模型
a = tanh(W Hidden_state + W Feature Map )
step 2 :注意力权重
a = softmax(a) 
step 3 : 相乘(Feature Map可以理解成Value)
output = a* Feature Map

如果从K,Q,V的角度来理解这个注意力机制,即用decoder的结果去做CNN feature Map的attention机制,即(CNN encoder decoder )+CNN 做attention
在这里插入图片描述

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        #step 1 : 如图所示,通过卷积,获得query(hidden states)
        #(这里的hidden states 第一个字符为encoder的输出Holistic Feature,之后的为解码器的LSTM的输出)
        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.size()
        attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)

		#step 2: feat 为backbone的输出,看成key,这里作为attention的feature map出现
        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

		#step 3 : 做加性模型的注意力机制
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.size()
        assert c == 1
		
		#掩码部分,去填充图谱未缩放所造成的问题
        if valid_ratios is not None:
            # cal mask of attention weight
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))
		#制作掩码部分
        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()
		#step 3 : 和掩码相加,并作为注意力权重进行相乘
        attn_feat = torch.sum(
            torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
        # bsz * (seq_len + 1) * C
		
		#step 4 : 预测最终结果
        # linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
            y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值