“TensorFlow中AttentionCellWrapper的attn_length是什么鬼?attention window又是什么鬼?”--初次使用AttentionCellWrapper做attentionRNN时可能会有些懵逼...
因为新版TensorFlow的API doc中关于AttentionCellWrapper的介绍有些含糊,它说是基于 0473这篇经典的attention paper 实现的,但其实不太match,0473中描述的是基于encoder-decoder的attention,而这却是个用于非encoder-decoder结构的attention机制,形式上、结构上有很大区别,所以看完0473后使用AttentionCellWrapper或看它的实现时会产生很多疑问,比如attn_length等参数并不能在paper中找到参考。
所以就有了本文、分享如下。
要想了解AttentionRNN、TF中的AttentionCellWrapper建议看:
- 旧版TensorFlow doc提到的 06733这篇paper ,其中“Machine Reader”思路比较接近,但也不完全相同
- 这篇lookback-rnn-attention-rnn,里面“... look at the outputs from the last n steps when generating the output for the current step....”,这里的”last n“对应的就是AttentionCellWrapper中的attn_length
AttentionRNN简单总结如下
设计初衷
总之就是改进普通RNN系列的不足,主要有:
1,LSTM/GRU等RNN系列的cell是按马尔科夫方式计算state的,即计算state[t]时仅使用了state[t-1],虽然state[t-1]包含了一些之前的信息,但是当sequence比较长或memory size不够大时,就会出现信息丢失(毕竟实际中cell state容量有限)。
2,LSTM等的输入是token-by-token的、即按序处理,它不能抓住token之间的的结构信息、不能建模token之间的关系。
......
简介
思路简单易懂,算是encoder-decoder中attention的简化版吧。但paper 06733与TF的实现在细节上有所不同,不赘述了,只说一下AttentionCellWrapper实现的大体思路吧:
1,首先,context不再单纯的使用cell state(劣势上文已说),而是将t-1、t-2...t- attn_length 时刻的state attention一把、加权累加得到
- attn_length:指定考虑最近几个state、即attention的window size
- TF实现中使用的卷积等价于遍历各state分别乘参数w
2,然后,在t时刻,将x[t]和之前的context[t-1]拼接变换成新的x[t]作为输入,通过RNN/LSTM等得到当前时刻的state[t],将之前的state attention得到当前时刻的context[t],在时间维上递归下去...哎呀,说起来好麻烦,直接上公式吧、一目了然、盗图如下:
通过这种方式能找到token[t] 与token[t-1] token[t-2]...token[t-attn_length]之间的关系,可视化效果--借用paper 06733中的图:
Demo
用法非常简单,对RNN cell wrap一下就好了,类似于DropoutWrapper,如下:
......
fw_cells, bw_cells = [], []
for _ in range(self.layer_num):
fw_cell = tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=2.0)
fw_cell = tf.contrib.rnn.AttentionCellWrapper(fw_cell, attn_length, state_is_tuple=True)
fw_cells.append(fw_cell)
......
总结
适合大数据、长序列(毕竟优化了long dependencies)
训练耗时增加很多