1. 开篇
基于深度学习的文字识别发展到现在,就解码方法来分类,大体上可以分成使用CTC的解码方式以及attention的解码方式,当然落在具体的实现上会有多种多样的细分。前两期介绍的都是基于CTC解码的论文,本期就来介绍一下基于attention的一篇经典论文,其中首次提出了attention drift这个概念,直到现在这个概念还在被引用。代码实现可见:GitHub - chibohe/text_recognition_toolbox: text_recognition_toolbox: The reimplementation of a series of classical scene text recognition papers with Pytorch in a uniform way.
2. 论文解读
2.1 总览
FAN是2017年提出的论文,论文全称是《Focusing Attention: Towards Accurate Text Recognition in Natural Images》。本文的整体架构是基于encoder-decoder的方式,整体流程可见下图。各个模块的功能如下:
-
encoder使用的是resnet,用作视觉特征的提取;
-
decoder使用的是基于attention的RNN,attention用于目标字符和视觉特征的匹配,然后用RNN进行串行的解码;
-
decoder又分为AN(Attention Network)和FN(Focusing Network),前者用作字符和特征的匹配,后者用于匹配的纠偏。
2.2 特征提取层
在之前介绍的两篇论文中,CRNN的encoder使用的是VGG,GRCNN的encoder是RCNN。而本文的encoder使用的是resnet,据论文作者介绍,这是首次在文字识别里引入resnet作为encoder. resnet最开始是用于图像分类的backbone,针对文字识别,作者做了一些微调,具体可见下图的conv4_x的pool和conv5_x里倒数第二个conv层。这样做是为了保留图片水平维度上更多的信息,有利于识别相对较长的文本。
2.3 解码层
在一般的encoder-decoder架构中,我们直接使用attention进行解码输出,也就是下面的AN。但是作者在实验过程中,发现了目标字符和视觉特征不匹配的情况,作者基于这个观察,提出了注意力偏移(attention drift)的概念。下面我们就论文的一个具体例子进行阐释,在下面的(a)图中,当RNN解码到第三步时,即t=3时,按理说我们应该输出字符"K",而要正确的预测出字符"K",我们应该将注意力集中在图片上"K"这个字符的位置。但是在实际过程中,模型却将注意力集中在了"K"偏左方的位置,所以导致模型将"K"误识别成了"1".在加上FN之后,如下面的(b)图所示,我们成功的对注意力进行了纠偏,让焦点落在了"K"中间的位置,从而成功识别出了"K".下面我们具体介绍AN和FN是如何设计的。
2.3.1 AN(Attention Network)
本文的AN和一般的AN并无任何区别,都是Bahdanau attention,核心就在于如何生成注意力向量。首先,我们记encoder产生的特征向量如下图所示,T是沿水平方向的长度,也就是上面resnet里的65,
然后在每个时间步上,我们计算前一个隐藏状态和各个特征向量的关联程度,也就是目标字符和视觉特征的匹配程度,具体公式如下。注意最后需要用softmax进行归一化,
将上面的注意力权重与特征向量进行线性组合,就得到了注意力向量,
接下来就是生成当前时间步隐藏状态,这里就是普通的RNN,
最后用一个前向网络得到最终的字符输出。
2.3.2 FN(Focusing Network)
我们按照论文中作者的写作顺序进行介绍,首先介绍一下感受野,我们都知道经过卷积后计算宽高的公式,这里计算感受野就是这个公式的反向计算。假设我们想计算在L层中(x, y)位置在L-1层的感受野,我们可以通过以下的公式进行计算,这四个点就形成了一个bounding box.然后我们可以反复向下计算,一直计算到第0层,即这个位置在原图的感受野,可以用另一个更大的bounding box表示。
有了以上的计算,针对第j个特征向量,我们就可以得到它在原图中的感受野,然后我们通过下式来表示这一片感受野的中心位置。
结合AN中的注意力权重,我们就可以得到一个在原图的中心点,就是上面那个注意力偏移里的"+"号。记为下式:
接下来我们在原图中裁剪一片区域,这个区域就可以理解为某个字符在原图中对应的视觉特征。这个区域的宽和高就是字符标注的最大宽高,所以在这里我们知道了如果要实现AN,那我们就必须有字符级别的标注。
然后针对这片区域的每个像素,我们都可以计算一个注意力权重,注意e是个k维的向量,k也就是标注数据对应字符集的大小。
再然后还是通过softmax进行归一化。
最后我们将上面的内容包裹进交叉熵损失。
这个focus loss和上面的attention loss进行线性组合就得到最终的loss.
3. 代码解读
这里resnet就不介绍了,基本上跟一般的resnet一样。我们重点看看attention是如何实现的。
首先是如何得到注意力向量,
class AttentionCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embeddings):
super(AttentionCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1)
self.hidden_size = hidden_size
self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
def forward(self, prev_hidden, batch_H, one_hot):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))
alpha = F.softmax(e, dim=1)
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
concat_context = torch.cat([context, one_hot], dim=1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
然后是利用attention和RNN进行串行的解码,
class Attention(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(Attention, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_classes = num_classes
self.generator = nn.Linear(hidden_size, num_classes)
self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
def _char_one_hot(self, input_char, onehot_dim):
input_char = input_char.unsqueeze(1)
batch_size = input_char.size(0)
one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
one_hot = one_hot.scatter_(1, input_char, 1)
return one_hot
def forward(self, batch_H, text, is_train=True):
batch_size = batch_H.size(0)
num_steps = text.size(1)
output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).zero_().to(device)
hidden = (torch.FloatTensor(batch_size, self.hidden_size).zero_().to(device),
torch.FloatTensor(batch_size, self.hidden_size).zero_().to(device))
if is_train:
for i in range(num_steps):
one_hot = self._char_one_hot(text[:, i], self.num_classes)
hidden, alpha = self.attention_cell(hidden, batch_H, one_hot)
output_hiddens[:, i, :] = hidden[0]
probs = self.generator(output_hiddens)
else:
targets = torch.FloatTensor(batch_size, self.num_classes).zero_().to(device)
probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).zero_().to(device)
for i in range(num_steps):
one_hot = self._char_one_hot(targets, self.num_classes)
hidden, alpha = self.attention_cell(hidden, batch_H, one_hot)
prob = self.generator(hidden[0])
probs[:, i, :] = prob
_, next_input = prob.max(axis=1)
targets = next_input
return probs
4. 收尾
FAN提出了一个重要的概念attention drift,并提出了FN模块来减轻这个问题,其背后的设计思路还是非常符合直觉的。唯一可惜的就是这样做就需要字符级别的标注,而在工业界中几乎是不可能实现的。下一篇我们将介绍基于矫正器TPS的经典论文ASTER.