详解tensorflow中的Attention机制

  最近在做基于attention的唇语识别,无奈网上关于tf中attention的具体实现没有较好的Demo,且版本大多不一致,琐碎而且凌乱,不得不自己翻开源码,阅读一番,收获颇丰,现分享与此。
  PS:本文基于tensorflow-gpu-1.4.0版本,阅读前,读者最好对Attention mechanism有一定的了解,不然可能会一头雾水。
  tf-1.4.0中,关于attention机制的代码位于tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py文件中,使用时如下:

from tensorflow.contrib.seq2seq.python.ops import *

  该py文件主要包含4大块:

  • Attention mechanism:用来实现计算不同类型的attention vector(即context加权和后的向量),包括:

    • _BaseAttentionMechanism类:所有Attention的基类,
      • BahdanauAttention:论文https://arxiv.org/abs/1409.0473中的实现:
        ut=vTtanh(W1h+W2dt)at=softmax(ut)ct=lLatlhl u t = v T t a n h ( W 1 h + W 2 d t ) a t = s o f t m a x ( u t ) c t = ∑ l L a l t h l
      • LuongAttention:论文https://arxiv.org/abs/1508.04025中的实现
        ut=dtW1hat=softmax(ut)ct=lLatlhl u t = d t W 1 h a t = s o f t m a x ( u t ) c t = ∑ l L a l t h l
      • 等等,这里只以这两个为例。
  • AttentionWrapperState类:用来存储整个计算过程中的state,类似rnn中的state(LSTMStateTuple),只不过这里还额外存储了attention,time等信息。

  • AttentionWrapper类:用来组件rnn cell和上述的所有类的实例,从而构建一个带有attention机制的Decoder
  • 一些公用方法,包括求解attention权重的softmax的替代函数hardmax,不同Attention的权重计算函数

  对整个结构有了大致了解后,我们来看每个类的参数,是如何对应到计算过程的,这里以BahdanauAttention为例:

def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               probability_fn=None,
               score_mask_value=None,
               dtype=None,
name="BahdanauAttention"):

num_units:官方解释是The depth of the query mechanism。
  我们看公式 ut=vTtanh(W1h+W2dt) u t = v T t a n h ( W 1 h + W 2 d t ) ,这里的num_units即为矩阵 w1 w 1 列数,也即我们的context信息 h h 经过上述的矩阵乘法之后的维度,我们可以看到,这个num_units用来初始化了一个memory_layer,这个memory_layer其实就是一个全连接,用来完成上式中的第一个矩阵乘法,因为在解码过程中h是不变的,所以这个矩阵乘法在初始化时就已经完成。

 super(BahdanauAttention, self).__init__(
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False, dtype=dtype),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
name=name)

  可以看到,num_units同样初始化了一个query_layer,它是用来完成上式中的第二个矩阵乘法,为了保证两个矩阵乘法的结果维度一致,因此两者都用num_units进行初始化,由于第二个矩阵乘法是跟每个时间步相关的,因此是在对应解码的时间步完成的。

memory:官方解释是The memory to query; usually the output of an RNN encoder。即我们解码中要用到的context信息,维度为[batch_size, time_step, context_dim]。
  回到前面,在其调用父类BaseAttentionMechanism的__init__方法时,完成了上述的第一个矩阵乘法,:

 with ops.name_scope(
        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
      self._values = _prepare_memory(
          memory, memory_sequence_length,
          check_inner_dims_defined=check_inner_dims_defined)
      self._keys = (
          self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
else self._values)

  可以看到,第一个矩阵乘法的结果保存在_keys里面,也即上述的由num_units初始化的全连接层与处理后的memory(_values)相乘的结果。

PS:在LuongAttention中这个query_layer是None,回看 ut=dtW1h u t = d t W 1 h 就很清楚了,它用memory_layer在初始化的时候完成了 W1h W 1 h ,同样保存为_keys用于解码,每个时间步只需用对应的 dt d t 与_keys相乘即可。

memory_sequence_length:类似tf中动态rnn的sequence_length,用来告诉那些文本信息是需要参与attention计算的,不参与计算的全部置0,默认是所有的memory都参与计算。由上述求得_values的_prepare_memory()方法完成这个工作。

normalize:官方解释是Whether to normalize the energy term。即是否实现论文 https://arxiv.org/abs/1602.07868中的标准化。

probability_fn:计算归一化权重的方式,默认是softmax,也可以是诸如hardmax的方法等等,该方法必须接受一个未归一化的权重score,返回对应的规范化的概率(即要求和为1)

  现在,我们有了搭建好的RNN(tf中的RNNcell,可以是多层也可以单层),选择好的attention mechanism,接下来就要把它们拼在一起,这个工作则是由AttentionWrapper完成的,先来看参数:

def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
name=None):

cell:rnn cell的实例,可以是单个cell,也可以是多个cell stack后的mutli layer rnn。

attention_mechanism:前面提到的任意attention类的实例,或者实例列表,用来求解多个attention权重。

attention_layer_size:官方的解释很长,意思就是说,它是用来控制我们最后生成的attention是怎么得来的,如果是None,则直接返回对应attention mechanism计算得到的加权和向量;如果不是None,则在调用_compute_attention方法时,得到的加权和向量还会与output进行concat,然后再经过一个线性映射,变成维度为attention_layer_size的向量:

  if attention_layer_size is not None:
     ...
      self._attention_layers = tuple(
          layers_core.Dense(
              attention_layer_size,
              name="attention_layer",
              use_bias=False,
              dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
def _compute_attention(attention_mechanism, cell_output, attention_state,attention_layer):
  ...
  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

return attention, alignments, next_attention_state

PS:一般得到的attention是指这里的attention_layer_size=None的情况,但是之后解码的过程中,我们会计算 ht^=tanh(W[att;dt]) h t ^ = t a n h ( W [ a t t ; d t ] ) ,因此如果这里设置不为None,其实就是帮我们完成了后面一步的线性映射,我们拿到返回的att之后,直接做一个 tanh t a n h 就好了。

alignment_history:是否将之前每一步的alignment存储在state中,主要用于后期的可视化,关注attention的关注点。

cell_input_fn:input送入cell的方式,默认是会将input和上一步的attention拼接起来送入rnn cell:

if cell_input_fn is None:
      cell_input_fn = (
lambda inputs, attention: array_ops.concat([inputs, attention], -1))

  这个方法会在每个时间步的开始,送入rnn cell的时候调用:

def call(self, inputs, state):
    ...
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

  如何不想拼接,直接传入:

cell_input_fn=lambda input, attention: input

output_attention:是否返回attention,如果为False则直接返回rnn cell的输出,注意,无论是否为True,每一个时间步的attention都会存储在这里的state(AttentionWrapperState的一个实例)中

def call(self, inputs, state):

    ...
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
    ...
    next_state = AttentionWrapperState(
    time=state.time + 1,
    cell_state=next_cell_state,
    attention=attention,
    attention_state=self._item_or_tuple(all_attention_states),
    alignments=self._item_or_tuple(all_alignments),
    alignment_history=self._item_or_tuple(maybe_all_histories))

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state

一个简单的应用实例:

    cell=[rnn.LSTMCell(cell_size) for i in range(num_layers)]
    mutli_layer = rnn.MultiRNNCell(cell)
    attention_mechanism = LuongAttention(num_units =num_units,
                            memory=context)
    att_wrapper = AttentionWrapper(cell=mutil_layer,
                        attention_mechanism=attention_mechanism,
                        attention_layer_size=att_size,
                        cell_input_fn=lambda input, attention: input)
    states = att_wrapper.zeros_state(batch_size, tf.float32)
    with tf.variable_scope(SCOPE, reuse=tf.AUTO_REUSE):
        for i in range(decode_time_step):
            h_bar_without_tanh, states=att_wrapper(_X, states)
            h_bar = tf.tanh(h_bar_without_tanh)
            _X = tf.nn.softmax(tf.matmul(h_bar, W), 1)
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值