attention_ocr源码

  1. 主要看sequence_layers.py这个脚本中才是实现了attention+decoder的部分,model中只是个架子。
  2. sequence_layer中也是直接调用了TF的api,如果想深入理解,还是需要看TF源码。先从sequence_layer入手。

AttentionWithAutoregression继承了Attention,Attention继承了SequenceLayerBase。其中create_logits调用的就是SequenceLayerBase中的,create_logits中的unroll_cell调用的是Attention里面的。

def create_logits(self):
    """Creates character sequence logits for a net specified in the constructor.
    A "main" method for the sequence layer which glues together all pieces.
    Returns:
      A tensor with shape [batch_size, seq_length, num_char_classes].
    """
    with tf.variable_scope('LSTM'):
      first_label = self.get_input(prev=None, i=0)
      decoder_inputs = [first_label] + [None] * (self._params.seq_length - 1)
        # 是全0矩阵吗?
      lstm_cell = tf.contrib.rnn.LSTMCell(
          self._mparams.num_lstm_units,
            # 就是输出的size
          use_peepholes=False,
            # 采用的是最早提出的LSTM的构造,1997年
          cell_clip=self._mparams.lstm_state_clip_value,
          state_is_tuple=True,
          initializer=orthogonal_initializer)
        # 构建了decoder的LSTM结构
      lstm_outputs, _ = self.unroll_cell(
          decoder_inputs=decoder_inputs,
          initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
          loop_function=self.get_input,
          cell=lstm_cell)
        # 调用了ATTENTION方法中的
    with tf.variable_scope('logits'):
      logits_list = [
          tf.expand_dims(self.char_logit(logit, i), dim=1)
          for i, logit in enumerate(lstm_outputs)
      ]

    return tf.concat(logits_list, 1)


class ATTENTION():

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    return tf.contrib.legacy_seq2seq.attention_decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        # 这两个参数的意义还咩有理解透彻
        attention_states=self._net,
        # _net中是CNN输出后和空间one-hot特征concat的特征
        cell=cell,
        loop_function=self.get_input)

https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/sequence_layers.py

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------

tf.contrib.legacy_seq2seq.attention_decoder源码解读:

如果是想理解attention的机制以及实现,还是需要读到TF源码中的,毕竟到这一部分的代码还是没有涉及到真正的模型实现。

我这里看源码是TF1.10的,这是因为attention_ocr源码中用的就是这个。其实是可以看更新一点的源码的。

 

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值