attention
核心思想
attention机制有三个重要组成部分: Q Q Q -> 输入 q u e r y query query, K K K-> 系数 k e y key key, V V V-> 知识库取值 v a l u e value value。具体目的为,把一个 q u e r y query query翻译成 v a l u e value value的组合结果,过程中用到系数 k e y key key进行加权,示意图如下所示:
下图为细节展开,可以看出 q u e r y query query通过 k e y key key和变换函数 F F F后得到score(s),经softmax归一化后得到系数 a a a,之后再将 a a a与 v a l u e value value加权得到最终的attention value。
优缺点
优点
- 一步到位的全局联系捕捉
- 并行计算减少模型训练时间
- 模型复杂度小,参数少
缺点
不能捕捉语序顺序的相关信息,因为其本身是一个词袋模型。
算法实现
import tensorflow as tf
def attention(Q, K, scaled_=True):
""" attention implementation
:param Q:
:param K:
:param scaled_: whether scaling logit by sqrt{dim of K}
:return: attention weight
"""
logit = tf.matmul(Q, K, transpose_b=True) # [batch_size, sequence_length, sequence_length]
if scaled_:
d_k = tf.cast(tf.shape(K)[-1], dtype=tf.float32)
logit = tf.divide(logit, tf.sqrt(d_k)) # [batch_size, sequence_length, sequence_length]
weight = tf.nn.softmax(logit, dim=-1) # [batch_size, sequence_length, sequence_length]
return weight
self-attention
核心思想
self-attention借助attention机制,计算每个单词与其他所有单词的关联,例如在翻译(I am on the bank of the river)的任务里,当遇到bank时,river就有较高的attention-score。利用这些attention-score就能得到一种加权表示,然后放到一个forward-network中得到新的表示,这一表示会考虑到上下文信息。
借用知乎大神的思路,我们的任务是得到"thinking"和"machines"两个单词的self-attention取值。第一步获取这两个单词的embeddding x 1 x_1 x1和 x 2 x_2 x2,对于某一个 x x x,分别与 W Q W^Q WQ、 W K W^K WK、 W V W^V WV相乘得到三个矩阵 Q , K , V Q, K, V Q,K,V,示意图如下所示。
之后经过一系列非线性变换得到最终的 z z z,示意图如下所示。
上述非线性变换的核心过程如下所示。
算法实现
def self_attention(data, **config):
""" self_attention implementation
:param data: input data
:param config: param container
:return: self attention weight
"""
Q = tf.layers.dense(data, config['hidden_dim']) # [batch_size, sequence_length, hidden_dim]
K = tf.layers.dense(data, config['hidden_dim']) # [batch_size, sequence_length, hidden_dim]
V = tf.layers.dense(data, config['n_classes']) # [batch_size, sequence_length, n_classes]
weight = attention(Q, K) # [batch_size, sequence_length, sequence_length]
s_attn = tf.matmul(weight, V) # [batch_size, sequence_length, n_classes]
return s_attn
multi-head attention
multi-head attention核心思想为一系列attention的叠加和拼接,示意图如下所示。