Implementation of the two attention mechanisms commonly used in machine translation:
(1) additive attention  (2) multiplicative attention
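For reference (this note is added here, not part of the original code), the score functions behind these two mechanisms are usually written as

    score_additive(h, \bar{h}_s)       = V^\top \tanh(W_1 \bar{h}_s + W_2 h)
    score_multiplicative(h, \bar{h}_s) = h^\top W \bar{h}_s

where h is the decoder hidden state and \bar{h}_s is one encoder output. The additive branch below follows the first form directly (W1 applied to enc_output, W2 to hidden, V to the tanh output); the multiplicative branch is a variant that applies its weight matrix W3 along the time axis of enc_output rather than along the hidden dimension.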
import tensorflow as tf

def attention(hidden, enc_output, W1, W2, W3, V, att_select=None):
    '''
    input:  hidden      decoder hidden state tensor
            enc_output  encoder output tensor for every time step
            W1, W2, V   the three weight layers needed by 'additive' attention
            W3          the single weight layer needed by 'multiplicate' attention
            att_select  which attention mechanism to use: 'additive' or 'multiplicate'
    output: weighted context tensor [batch_size, 1, hidden_size] and the attention weights
    '''
    if att_select == 'additive':
        # hidden: [batch_size, hidden_size] ----> [batch_size, 1, hidden_size]
        hidden_with_time_axis = tf.expand_dims(hidden, axis=1)
        # temp: [batch_size, enc_output.shape[1], hidden_size]
        temp = tf.nn.tanh(W1(enc_output) + W2(hidden_with_time_axis))
        # score: [batch_size, enc_output.shape[1], 1]
        score = V(temp)
        attention_weights = tf.nn.softmax(score, axis=1)
        # context_vector: [batch_size, enc_output.shape[1], hidden_size]
        #            ---> [batch_size, hidden_size]
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        # context_vector: [batch_size, 1, hidden_size]
        context_vector = tf.expand_dims(context_vector, axis=1)
    elif att_select == 'multiplicate':
        # hidden: [batch_size, hidden_size] ----> [batch_size, hidden_size, 1]
        hidden_with_time_axis = tf.expand_dims(hidden, axis=2)
        # enc_output: [batch_size, enc_output.shape[1], hidden_size]
        #        ----> [batch_size, hidden_size, enc_output.shape[1]]
        temp = W3(tf.transpose(enc_output, perm=[0, 2, 1]))
        # score: [batch_size, hidden_size, enc_output.shape[1]]
        score = temp * hidden_with_time_axis
        # ----> [batch_size, enc_output.shape[1], hidden_size]
        score = tf.transpose(score, perm=[0, 2, 1])
        # ----> [batch_size, enc_output.shape[1], 1]
        score = tf.expand_dims(tf.reduce_sum(score, axis=2), axis=2)
        attention_weights = tf.nn.softmax(score, axis=1)
        # same weighted sum over encoder outputs as in the additive branch
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        # context_vector: [batch_size, 1, hidden_size]
        context_vector = tf.expand_dims(context_vector, axis=1)
    return context_vector, attention_weights
An explanation of the function's inputs:
hidden: shape [batch_size, hidden_size] (the decoder hidden state; the function adds the time axis itself)
enc_output: shape [batch_size, text_len, hidden_size]
W1: a [hidden_size, hidden_size] weight matrix, implemented with tf.keras.layers.Dense
W2: same shape as W1, [hidden_size, hidden_size]
W3: [text_len, text_len]
V: [hidden_size, 1]
att_select: 'additive' or 'multiplicate'
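A minimal usage sketch, assuming the shapes listed above; the sizes batch_size=4, text_len=20, hidden_size=128 and the random inputs are placeholders chosen for illustration, not values from the original:

import tensorflow as tf

batch_size, text_len, hidden_size = 4, 20, 128   # illustrative sizes only

# the weight "matrices" realized as Dense layers, matching the shapes listed above
W1 = tf.keras.layers.Dense(hidden_size)   # acts on enc_output:        hidden_size -> hidden_size
W2 = tf.keras.layers.Dense(hidden_size)   # acts on the decoder state: hidden_size -> hidden_size
W3 = tf.keras.layers.Dense(text_len)      # acts on the time axis:     text_len -> text_len
V  = tf.keras.layers.Dense(1)             # turns each position into a scalar score

hidden = tf.random.normal([batch_size, hidden_size])
enc_output = tf.random.normal([batch_size, text_len, hidden_size])

context, weights = attention(hidden, enc_output, W1, W2, W3, V, att_select='additive')
print(context.shape)   # (4, 1, 128)
print(weights.shape)   # (4, 20, 1)

context, weights = attention(hidden, enc_output, W1, W2, W3, V, att_select='multiplicate')
print(context.shape)   # (4, 1, 128)

In a real decoder the context vector would then be concatenated with the current input embedding before the next decoding step; the sketch above only checks that both branches produce the documented shapes.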