文字识别领域经典论文回顾第四期:ASTER

1. 开篇

在之前介绍的三篇论文,处理的对象都是水平的文本,常见于各类票据卡证等。而在自然场景中,因为图片本身属性的问题,加上拍摄角度的不同,往往会造成图片里文字的扭曲,主要包括弯曲、透视、多方向等问题。为了解决此类问题,本文提出了一类基于矫正器的文字识别方法。代码可以参考:https://github.com/ayumiymk/aster.pytorch

2. 论文解读

2.1 总览

ASTER是2018年提出的论文,论文的全称是《ASTER: An Attentional Scene Text Recognizer with Flexible Rectification》。本文跟之前的FAN一样,仍然是基于encoder-decoder的方式,整体的模型架构以下三块:

  1. TPS(Thin-Plate-Spline):分为localization network和grid sampler,前者用于回归出控制点,后者用于在原图上进行网格采样;

  2. encoder:卷积神经网络用的是resnet,语言模型使用的是BiLSTM,需要说明的是在后续的DTRB论文中语言模型会单独拆分出来,在这里还是和原论文保持一致;

  3. decoder:使用的是基于bahdanau attention的decoder,这里用了两个LSTM decoder。一个从左到右,一个从右到左,进行双向的解码。

2.2 矫正器

从模型结构的总览可以看出,ASTER其实和FAN有诸多的相似之处,最大的不同就在于TPS模块。所以,我们就重点介绍一下这个模块究竟是怎么实现文字的矫正的。首先我们看一下TPS的整体结构,对于形状为(N,C,H_in,W_in)的输入图像I,经过下采样得到I_d,然后通过localization network得到控制点C’。有了C‘我们可以通过TPS得到一个矩阵变换T,接下来我们通过grid generator得到网格P,形状为 (N, H_out, W_out, 2),最后一维的2代表xy。接下来我们通过矩阵变换T将网格P映射至原图上得到P’,形状仍然为 (N, H_out, W_out, 2)。最后根据原图的网格P'采样得到I_r.下面我们进行一一讲解。

2.2.1 Localization Network

localization network就是一个卷积神经网络,里面都是3x3的conv block,最终通过全连接层得到控制点C‘,形状为(20, 2). 20代表上下各10个点,第二维是xy坐标。在这里需要注意全连接层的数值初始化的问题。作者通过对比试验证明了,当全连接层的偏置项初始化为[(0.01, 0.01), (0.02, 0.01), ..., (0.01, 0.99), ..., (0.99, 0.99)]时,即在图片的上下边缘等距采样时,模型收敛的速度更快。

2.2.2 Thin Plate Transformation

由localization network我们得到了C’,然后我们同样用等距采样得到C,C的形状跟C‘一致,但是每两点的距离不是0.01,而是0.05.接下来我们通过如下的矩阵运算得到变换矩阵T,

2.2.4 Sampler

首先利用grid generator得到网格P,然后通过下式我们将P映射到原图的P’.注意P和P‘数值范围都在0到1之间,但在最终进行插值输出的过程中,我们会将P’映射到-1到1之间,这个会在下面的代码看出。

在这里我们稍微总结一下。从下图可以看出,其实TPS就是要得到一个变换矩阵,其中C‘是需要进行学习的参数,而C是不变的,即手动调整的参数。根据C和C’可以得到T,然后在原图上采样就得到最终矫正之后的图像。

 

2.3 特征提取层

本文的特征提取层跟上一篇的FAN一致,都是先经过resnet,然后经过双向的LSTM,最终得到形状为(B, W, C)的三维特征向量,其中B代表batch size, W是time steps,C是channels.比如说根据原文,当输入大小为(32, 100)时,输出就是(B, 25, 512)

2.4 解码层

本文的解码层和FAN基本类似,但有两处改进。第一点是将原先FAN的单向attention解码改成了双向的attention解码,这点改进的出发点是非常直观的。比如当解码到一个特定的字符时,该字符不仅与左边的语义信息相关,也与右边的相关。双向解码具体的做法如下,分别从左到右以及从右到左进行解码输出,然后去log-softmax得分高的作为最终的输出。这里使用的attention与FAN中的一致,都是bahdanau attention,具体公式就不赘述了。

第二处改进是在最终预测输出的时候,原先我们一般取每个时间步概率最大的字符进行输出,本文改成了束搜索,搜索宽度一般设置成5

3. 代码解读

我们重点看看TPS以及attention decoder,这里的attention decoder用的还是单向的。如果想改成双向的话,直接将(B, L, C)中L的顺序改为从右至左就行。

3.1 TPS

首先我们看看如何回归得到C‘,注意是如何对最后一个全连接层进行初始化的。

def conv3x3_block(in_planes, out_planes, stride=1):
  """3x3 convolution with padding"""
  conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)

  block = nn.Sequential(
    conv_layer,
    nn.BatchNorm2d(out_planes),
    nn.ReLU(inplace=True),
  )
  return block


class STNHead(nn.Module):
  def __init__(self, in_planes, num_ctrlpoints, activation='none'):
    super(STNHead, self).__init__()

    self.in_planes = in_planes
    self.num_ctrlpoints = num_ctrlpoints
    self.activation = activation
    # 一路都是3x3的conv block,中间用max pooling将宽高各减一半
    self.stn_convnet = nn.Sequential(
                          conv3x3_block(in_planes, 32), # 32*64
                          nn.MaxPool2d(kernel_size=2, stride=2),
                          conv3x3_block(32, 64), # 16*32
                          nn.MaxPool2d(kernel_size=2, stride=2),
                          conv3x3_block(64, 128), # 8*16
                          nn.MaxPool2d(kernel_size=2, stride=2),
                          conv3x3_block(128, 256), # 4*8
                          nn.MaxPool2d(kernel_size=2, stride=2),
                          conv3x3_block(256, 256), # 2*4,
                          nn.MaxPool2d(kernel_size=2, stride=2),
                          conv3x3_block(256, 256)) # 1*2

    self.stn_fc1 = nn.Sequential(
                      nn.Linear(2*256, 512),
                      nn.BatchNorm1d(512),
                      nn.ReLU(inplace=True))
    self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)

    self.init_weights(self.stn_convnet)
    self.init_weights(self.stn_fc1)
    self.init_stn(self.stn_fc2)

	# 对全连接层stn_fc2进行初始化,间隔为0.01
  def init_stn(self, stn_fc2):
    margin = 0.01
    sampling_num_per_side = int(self.num_ctrlpoints / 2)
    ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side)
    ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
    ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin)
    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
    ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
    if self.activation is 'none':
      pass
    elif self.activation == 'sigmoid':
      ctrl_points = -np.log(1. / ctrl_points - 1.)
    stn_fc2.weight.data.zero_()
    stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

3.2 attention decoder

这个实现是用GRU进行解码的,而FAN里使用的是LSTM。另外这个实现是将输入(B, L, W)中的L变成1,所以可以直接用GRU,而不是GRUCell进行解码。但其实我觉得用GRUCell解码更为直观一些。

class AttentionRecognitionHead(nn.Module):
  """
  input: [b x 16 x 64 x in_planes]
  output: probability sequence: [b x T x num_classes]
  """
  def __init__(self, num_classes, in_planes, sDim, attDim, max_len_labels):
    super(AttentionRecognitionHead, self).__init__()
    self.num_classes = num_classes # this is the output classes. So it includes the <EOS>.
    self.in_planes = in_planes
    self.sDim = sDim
    self.attDim = attDim
    self.max_len_labels = max_len_labels

    self.decoder = DecoderUnit(sDim=sDim, xDim=in_planes, yDim=num_classes, attDim=attDim)

  def forward(self, x):
    x, targets, lengths = x
    batch_size = x.size(0)
    # Decoder
    # 注意这里的1就是时间步
    state = torch.zeros(1, batch_size, self.sDim)
    outputs = []

    for i in range(max(lengths)):
      if i == 0:
        y_prev = torch.zeros((batch_size)).fill_(self.num_classes) # the last one is used as the <BOS>.
      else:
        y_prev = targets[:,i-1]

      output, state = self.decoder(x, state, y_prev)
      outputs.append(output)
    outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
    return outputs
  
  
  class AttentionUnit(nn.Module):
  def __init__(self, sDim, xDim, attDim):
    super(AttentionUnit, self).__init__()

    self.sDim = sDim
    self.xDim = xDim
    self.attDim = attDim

    self.sEmbed = nn.Linear(sDim, attDim)
    self.xEmbed = nn.Linear(xDim, attDim)
    self.wEmbed = nn.Linear(attDim, 1)


  def forward(self, x, sPrev):
    batch_size, T, _ = x.size()                      # [b x T x xDim]
    x = x.view(-1, self.xDim)                        # [(b x T) x xDim]
    xProj = self.xEmbed(x)                           # [(b x T) x attDim]
    xProj = xProj.view(batch_size, T, -1)            # [b x T x attDim]

    sPrev = sPrev.squeeze(0)
    sProj = self.sEmbed(sPrev)                       # [b x attDim]
    sProj = torch.unsqueeze(sProj, 1)                # [b x 1 x attDim]
    sProj = sProj.expand(batch_size, T, self.attDim) # [b x T x attDim]

    sumTanh = torch.tanh(sProj + xProj)
    sumTanh = sumTanh.view(-1, self.attDim)

    vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
    vProj = vProj.view(batch_size, T)

    alpha = F.softmax(vProj, dim=1) # attention weights for each sample in the minibatch

    return alpha


class DecoderUnit(nn.Module):
  def __init__(self, sDim, xDim, yDim, attDim):
    super(DecoderUnit, self).__init__()
    self.sDim = sDim
    self.xDim = xDim
    self.yDim = yDim
    self.attDim = attDim
    self.emdDim = attDim

    self.attention_unit = AttentionUnit(sDim, xDim, attDim)
    self.tgt_embedding = nn.Embedding(yDim+1, self.emdDim) # the last is used for <BOS> 
    self.gru = nn.GRU(input_size=xDim+self.emdDim, hidden_size=sDim, batch_first=True)
    self.fc = nn.Linear(sDim, yDim)


  def forward(self, x, sPrev, yPrev):
    # x: feature sequence from the image decoder.
    batch_size, T, _ = x.size()
    alpha = self.attention_unit(x, sPrev)
    context = torch.bmm(alpha.unsqueeze(1), x).squeeze(1)
    yProj = self.tgt_embedding(yPrev.long())
    # self.gru.flatten_parameters()
    output, state = self.gru(torch.cat([yProj, context], 1).unsqueeze(1), sPrev)
    output = output.squeeze(1)

    output = self.fc(output)
    return output, state

4. 收尾

ASTER在一般attention based的encoder-decoder基础上,加上了TPS作为矫正模块,可以部分缓解由于弯曲文字导致的识别不准确问题。后续也有不少论文是沿着这个方向进行改进的,比如说MORAN、ESIR等等。下一篇我会继续沿着识别弯曲文本的方向,介绍利用2d attention进行文字识别的论文SAR.

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值