最近在做基于attention的唇语识别,无奈网上关于tf中attention的具体实现没有较好的Demo,且版本大多不一致,琐碎而且凌乱,不得不自己翻开源码,阅读一番,收获颇丰,现分享与此。
PS:本文基于tensorflow-gpu-1.4.0版本,阅读前,读者最好对Attention mechanism有一定的了解,不然可能会一头雾水。
tf-1.4.0中,关于attention机制的代码位于tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py文件中,使用时如下:
from tensorflow.contrib.seq2seq.python.ops import *
该py文件主要包含4大块:
Attention mechanism:用来实现计算不同类型的attention vector(即context加权和后的向量),包括:
- _BaseAttentionMechanism类:所有Attention的基类,
- BahdanauAttention:论文https://arxiv.org/abs/1409.0473中的实现:
ut=vTtanh(W1h+W2dt)at=softmax(ut)ct=∑lLatlhl u t = v T t a n h ( W 1 h + W 2 d t ) a t = s o f t m a x ( u t ) c t = ∑ l L a l t h l - LuongAttention:论文https://arxiv.org/abs/1508.04025中的实现
ut=dtW1hat=softmax(ut)ct=∑lLatlhl u t = d t W 1 h a t = s o f t m a x ( u t ) c t = ∑ l L a l t h l - 等等,这里只以这两个为例。
- BahdanauAttention:论文https://arxiv.org/abs/1409.0473中的实现:
- _BaseAttentionMechanism类:所有Attention的基类,
AttentionWrapperState类:用来存储整个计算过程中的state,类似rnn中的state(LSTMStateTuple),只不过这里还额外存储了attention,time等信息。
- AttentionWrapper类:用来组件rnn cell和上述的所有类的实例,从而构建一个带有attention机制的Decoder
- 一些公用方法,包括求解attention权重的softmax的替代函数hardmax,不同Attention的权重计算函数
对整个结构有了大致了解后,我们来看每个类的参数,是如何对应到计算过程的,这里以BahdanauAttention为例:
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
score_mask_value=None,
dtype=None,
name="BahdanauAttention"):
num_units
:官方解释是The depth of the query mechanism。
我们看公式
ut=vTtanh(W1h+W2dt)
u
t
=
v
T
t
a
n
h
(
W
1
h
+
W
2
d
t
)
,这里的num_units即为矩阵
w1
w
1
列数,也即我们的context信息
h
h
经过上述的矩阵乘法之后的维度,我们可以看到,这个num_units用来初始化了一个memory_layer,这个memory_layer其实就是一个全连接,用来完成上式中的第一个矩阵乘法,因为在解码过程中是不变的,所以这个矩阵乘法在初始化时就已经完成。
super(BahdanauAttention, self).__init__(
query_layer=layers_core.Dense(
num_units, name="query_layer", use_bias=False, dtype=dtype),
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False, dtype=dtype),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
score_mask_value=score_mask_value,
name=name)
可以看到,num_units同样初始化了一个query_layer,它是用来完成上式中的第二个矩阵乘法,为了保证两个矩阵乘法的结果维度一致,因此两者都用num_units进行初始化,由于第二个矩阵乘法是跟每个时间步相关的,因此是在对应解码的时间步完成的。
memory
:官方解释是The memory to query; usually the output of an RNN encoder。即我们解码中要用到的context信息,维度为[batch_size, time_step, context_dim]。
回到前面,在其调用父类BaseAttentionMechanism的__init__方法时,完成了上述的第一个矩阵乘法,:
with ops.name_scope(
name, "BaseAttentionMechanismInit", nest.flatten(memory)):
self._values = _prepare_memory(
memory, memory_sequence_length,
check_inner_dims_defined=check_inner_dims_defined)
self._keys = (
self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable
else self._values)
可以看到,第一个矩阵乘法的结果保存在_keys里面,也即上述的由num_units初始化的全连接层与处理后的memory(_values)相乘的结果。
PS:在LuongAttention中这个query_layer是None,回看
ut=dtW1h
u
t
=
d
t
W
1
h
就很清楚了,它用memory_layer在初始化的时候完成了
W1h
W
1
h
,同样保存为_keys用于解码,每个时间步只需用对应的
dt
d
t
与_keys相乘即可。
memory_sequence_length
:类似tf中动态rnn的sequence_length,用来告诉那些文本信息是需要参与attention计算的,不参与计算的全部置0,默认是所有的memory都参与计算。由上述求得_values的_prepare_memory()方法完成这个工作。
normalize
:官方解释是Whether to normalize the energy term。即是否实现论文 https://arxiv.org/abs/1602.07868中的标准化。
probability_fn
:计算归一化权重的方式,默认是softmax,也可以是诸如hardmax的方法等等,该方法必须接受一个未归一化的权重score,返回对应的规范化的概率(即要求和为1)
现在,我们有了搭建好的RNN(tf中的RNNcell,可以是多层也可以单层),选择好的attention mechanism,接下来就要把它们拼在一起,这个工作则是由AttentionWrapper完成的,先来看参数:
def __init__(self,
cell,
attention_mechanism,
attention_layer_size=None,
alignment_history=False,
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
name=None):
cell
:rnn cell的实例,可以是单个cell,也可以是多个cell stack后的mutli layer rnn。
attention_mechanism
:前面提到的任意attention类的实例,或者实例列表,用来求解多个attention权重。
attention_layer_size
:官方的解释很长,意思就是说,它是用来控制我们最后生成的attention是怎么得来的,如果是None,则直接返回对应attention mechanism计算得到的加权和向量;如果不是None,则在调用_compute_attention方法时,得到的加权和向量还会与output进行concat,然后再经过一个线性映射,变成维度为attention_layer_size的向量:
if attention_layer_size is not None:
...
self._attention_layers = tuple(
layers_core.Dense(
attention_layer_size,
name="attention_layer",
use_bias=False,
dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
def _compute_attention(attention_mechanism, cell_output, attention_state,attention_layer):
...
if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context
return attention, alignments, next_attention_state
PS:一般得到的attention是指这里的attention_layer_size=None的情况,但是之后解码的过程中,我们会计算 ht^=tanh(W[att;dt]) h t ^ = t a n h ( W [ a t t ; d t ] ) ,因此如果这里设置不为None,其实就是帮我们完成了后面一步的线性映射,我们拿到返回的att之后,直接做一个 tanh t a n h 就好了。
alignment_history
:是否将之前每一步的alignment存储在state中,主要用于后期的可视化,关注attention的关注点。
cell_input_fn
:input送入cell的方式,默认是会将input和上一步的attention拼接起来送入rnn cell:
if cell_input_fn is None:
cell_input_fn = (
lambda inputs, attention: array_ops.concat([inputs, attention], -1))
这个方法会在每个时间步的开始,送入rnn cell的时候调用:
def call(self, inputs, state):
...
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
如何不想拼接,直接传入:
cell_input_fn=lambda input, attention: input
output_attention
:是否返回attention,如果为False则直接返回rnn cell的输出,注意,无论是否为True,每一个时间步的attention都会存储在这里的state(AttentionWrapperState的一个实例)中
def call(self, inputs, state):
...
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
...
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
attention_state=self._item_or_tuple(all_attention_states),
alignments=self._item_or_tuple(all_alignments),
alignment_history=self._item_or_tuple(maybe_all_histories))
if self._output_attention:
return attention, next_state
else:
return cell_output, next_state
一个简单的应用实例:
cell=[rnn.LSTMCell(cell_size) for i in range(num_layers)]
mutli_layer = rnn.MultiRNNCell(cell)
attention_mechanism = LuongAttention(num_units =num_units,
memory=context)
att_wrapper = AttentionWrapper(cell=mutil_layer,
attention_mechanism=attention_mechanism,
attention_layer_size=att_size,
cell_input_fn=lambda input, attention: input)
states = att_wrapper.zeros_state(batch_size, tf.float32)
with tf.variable_scope(SCOPE, reuse=tf.AUTO_REUSE):
for i in range(decode_time_step):
h_bar_without_tanh, states=att_wrapper(_X, states)
h_bar = tf.tanh(h_bar_without_tanh)
_X = tf.nn.softmax(tf.matmul(h_bar, W), 1)