文字识别领域经典论文回顾第三期:FAN

1. 开篇

基于深度学习的文字识别发展到现在,就解码方法来分类,大体上可以分成使用CTC的解码方式以及attention的解码方式,当然落在具体的实现上会有多种多样的细分。前两期介绍的都是基于CTC解码的论文,本期就来介绍一下基于attention的一篇经典论文,其中首次提出了attention drift这个概念,直到现在这个概念还在被引用。代码实现可见:GitHub - chibohe/text_recognition_toolbox: text_recognition_toolbox: The reimplementation of a series of classical scene text recognition papers with Pytorch in a uniform way.

2. 论文解读

2.1 总览

FAN是2017年提出的论文,论文全称是《Focusing Attention: Towards Accurate Text Recognition in Natural Images》。本文的整体架构是基于encoder-decoder的方式,整体流程可见下图。各个模块的功能如下:

  1. encoder使用的是resnet,用作视觉特征的提取;

  2. decoder使用的是基于attention的RNN,attention用于目标字符和视觉特征的匹配,然后用RNN进行串行的解码;

  3. decoder又分为AN(Attention Network)和FN(Focusing Network),前者用作字符和特征的匹配,后者用于匹配的纠偏。

2.2 特征提取层

在之前介绍的两篇论文中,CRNN的encoder使用的是VGG,GRCNN的encoder是RCNN。而本文的encoder使用的是resnet,据论文作者介绍,这是首次在文字识别里引入resnet作为encoder. resnet最开始是用于图像分类的backbone,针对文字识别,作者做了一些微调,具体可见下图的conv4_x的pool和conv5_x里倒数第二个conv层。这样做是为了保留图片水平维度上更多的信息,有利于识别相对较长的文本。

2.3 解码层

在一般的encoder-decoder架构中,我们直接使用attention进行解码输出,也就是下面的AN。但是作者在实验过程中,发现了目标字符和视觉特征不匹配的情况,作者基于这个观察,提出了注意力偏移(attention drift)的概念。下面我们就论文的一个具体例子进行阐释,在下面的(a)图中,当RNN解码到第三步时,即t=3时,按理说我们应该输出字符"K",而要正确的预测出字符"K",我们应该将注意力集中在图片上"K"这个字符的位置。但是在实际过程中,模型却将注意力集中在了"K"偏左方的位置,所以导致模型将"K"误识别成了"1".在加上FN之后,如下面的(b)图所示,我们成功的对注意力进行了纠偏,让焦点落在了"K"中间的位置,从而成功识别出了"K".下面我们具体介绍AN和FN是如何设计的。

2.3.1 AN(Attention Network)

本文的AN和一般的AN并无任何区别,都是Bahdanau attention,核心就在于如何生成注意力向量。首先,我们记encoder产生的特征向量如下图所示,T是沿水平方向的长度,也就是上面resnet里的65,

然后在每个时间步上,我们计算前一个隐藏状态和各个特征向量的关联程度,也就是目标字符和视觉特征的匹配程度,具体公式如下。注意最后需要用softmax进行归一化,

 

将上面的注意力权重与特征向量进行线性组合,就得到了注意力向量,

 

接下来就是生成当前时间步隐藏状态,这里就是普通的RNN,

 

最后用一个前向网络得到最终的字符输出。

 

2.3.2 FN(Focusing Network)

我们按照论文中作者的写作顺序进行介绍,首先介绍一下感受野,我们都知道经过卷积后计算宽高的公式,这里计算感受野就是这个公式的反向计算。假设我们想计算在L层中(x, y)位置在L-1层的感受野,我们可以通过以下的公式进行计算,这四个点就形成了一个bounding box.然后我们可以反复向下计算,一直计算到第0层,即这个位置在原图的感受野,可以用另一个更大的bounding box表示。

有了以上的计算,针对第j个特征向量,我们就可以得到它在原图中的感受野,然后我们通过下式来表示这一片感受野的中心位置。

 

结合AN中的注意力权重,我们就可以得到一个在原图的中心点,就是上面那个注意力偏移里的"+"号。记为下式:

 

接下来我们在原图中裁剪一片区域,这个区域就可以理解为某个字符在原图中对应的视觉特征。这个区域的宽和高就是字符标注的最大宽高,所以在这里我们知道了如果要实现AN,那我们就必须有字符级别的标注。

 

然后针对这片区域的每个像素,我们都可以计算一个注意力权重,注意e是个k维的向量,k也就是标注数据对应字符集的大小。

 

再然后还是通过softmax进行归一化。

 

最后我们将上面的内容包裹进交叉熵损失。

 

这个focus loss和上面的attention loss进行线性组合就得到最终的loss.

3. 代码解读

这里resnet就不介绍了,基本上跟一般的resnet一样。我们重点看看attention是如何实现的。

首先是如何得到注意力向量,

class AttentionCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)

    def forward(self, prev_hidden, batch_H, one_hot):
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))

        alpha = F.softmax(e, dim=1)
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
        concat_context = torch.cat([context, one_hot], dim=1)
        cur_hidden = self.rnn(concat_context, prev_hidden)

        return cur_hidden, alpha

 然后是利用attention和RNN进行串行的解码,

class Attention(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)

    def _char_one_hot(self, input_char, onehot_dim):
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def forward(self, batch_H, text, is_train=True):
        batch_size = batch_H.size(0)
        num_steps = text.size(1)

        output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).zero_().to(device)
        hidden = (torch.FloatTensor(batch_size, self.hidden_size).zero_().to(device),
                  torch.FloatTensor(batch_size, self.hidden_size).zero_().to(device))
        
        if is_train:
            for i in range(num_steps):
                one_hot = self._char_one_hot(text[:, i], self.num_classes)
                hidden, alpha = self.attention_cell(hidden, batch_H, one_hot)
                output_hiddens[:, i, :] = hidden[0]
                probs = self.generator(output_hiddens)
        else:
            targets = torch.FloatTensor(batch_size, self.num_classes).zero_().to(device)
            probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).zero_().to(device)
            
            for i in range(num_steps):
                one_hot = self._char_one_hot(targets, self.num_classes)
                hidden, alpha = self.attention_cell(hidden, batch_H, one_hot)
                prob = self.generator(hidden[0])
                probs[:, i, :] = prob
                _, next_input = prob.max(axis=1)
                targets = next_input
        
        return probs

 

4. 收尾

FAN提出了一个重要的概念attention drift,并提出了FN模块来减轻这个问题,其背后的设计思路还是非常符合直觉的。唯一可惜的就是这样做就需要字符级别的标注,而在工业界中几乎是不可能实现的。下一篇我们将介绍基于矫正器TPS的经典论文ASTER.

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值