Attention is actually quite simple. Take this translation:
我喜欢游泳 -> I like swimming
When translating, the closer a source word corresponds to the word being produced, the more attention it receives and the larger its influence. For example:
I = f(0.7(“我”) + 0.2(“喜欢”) + 0.1(“游泳”))
like = f(0.2(“我”) + 0.6(“喜欢”) + 0.2(“游泳”))
swimming = f(0.1(“我”) + 0.2(“喜欢”) + 0.7(“游泳”))
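To make the weighted sum concrete, here is a toy sketch with made-up 4-dimensional word vectors (purely illustrative, not from any trained model):

import numpy as np

# Hypothetical encoder vectors for each source word.
wo      = np.array([0.9, 0.1, 0.0, 0.2])  # "我"
xihuan  = np.array([0.1, 0.8, 0.1, 0.0])  # "喜欢"
youyong = np.array([0.0, 0.1, 0.9, 0.3])  # "游泳"

# Attention weights when producing "I": mostly "我".
weights = np.array([0.7, 0.2, 0.1])

# Context vector = weighted sum of the source word vectors.
context = weights[0] * wo + weights[1] * xihuan + weights[2] * youyong
print(context)  # this is what f() sees when generating "I"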
TensorFlow has two attention mechanisms: Bahdanau (additive) attention and Luong (multiplicative) attention.
The score formulas are as follows (written in the notation of the walkthrough below; the Bahdanau score is exactly what the Keras code implements, and the Luong "general" score is the standard multiplicative variant):

Bahdanau (additive): score = V(tanh(W1(EO) + W2(H)))
Luong (general, multiplicative): score = H · W · EO
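The code below only implements the Bahdanau score; since the excerpt never shows Luong attention, here is a minimal sketch of the Luong "general" score on random tensors (variable names and sizes are my own, assuming TF 2.x eager execution):

import tensorflow as tf

batch, max_len, units = 2, 5, 8
EO = tf.random.normal((batch, max_len, units))   # encoder outputs
H = tf.random.normal((batch, units))             # decoder hidden state

# Luong "general" score: H^T W EO, one scalar per encoder time step.
W = tf.keras.layers.Dense(units, use_bias=False)
score = tf.matmul(W(EO), tf.expand_dims(H, 2))   # (batch, max_len, 1)
attention_weights = tf.nn.softmax(score, axis=1) # (batch, max_len, 1)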
The code is explained below:
Attention part:
FC = fully connected (Dense) layer; its number of units is the decoder's hidden size dec_units
EO = encoder output, shape (batch_size, states, enc_hidden), where "states" is the number of source time steps (max_length in the code)
H = decoder hidden state, shape (batch_size, 1, dec_hidden); at the start of decoding it is the encoder's final state (enc_state)
X = input to the decoder
score = FC(tanh(FC(EO) + FC(H))) # output shape (batch_size, states, 1)
attention weights = softmax(score, axis=1) # output shape (batch_size, states, 1). These two steps score EO against H, giving, for each encoder state (each input word), how strongly it influences the current decoder output (the word being generated).
context vector = sum(attention weights * EO, axis=1) # output shape (batch_size, enc_hidden). attention weights * EO scales each state by its influence coefficient: e.g. a weight of 0.6 on state 1 means the first input word contributes only 0.6 of its content to the current output word. Summing then gives the total influence of all states, after weighting, on the current output.
embedding output = embedding(X) # output shape (batch_size, 1, embedding_dim)
merged vector = concat(embedding output, context vector) # output shape (batch_size, 1, embedding_dim + hidden_size). The embedding output and the context vector are concatenated together.
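Putting the walkthrough together, a minimal runnable sketch (random tensors and made-up sizes, assuming TF 2.x) that reproduces each shape above:

import tensorflow as tf

batch, states, enc_hidden, dec_units, emb_dim, vocab = 64, 16, 1024, 1024, 256, 100

EO = tf.random.normal((batch, states, enc_hidden))               # encoder output
H = tf.random.normal((batch, dec_units))                         # decoder hidden state
X = tf.random.uniform((batch, 1), maxval=vocab, dtype=tf.int32)  # decoder input token ids

W1, W2 = tf.keras.layers.Dense(dec_units), tf.keras.layers.Dense(dec_units)
V = tf.keras.layers.Dense(1)
embedding = tf.keras.layers.Embedding(vocab, emb_dim)

score = V(tf.nn.tanh(W1(EO) + W2(tf.expand_dims(H, 1))))        # (batch, states, 1)
attention_weights = tf.nn.softmax(score, axis=1)                # (batch, states, 1)
context_vector = tf.reduce_sum(attention_weights * EO, axis=1)  # (batch, enc_hidden)
embedding_output = embedding(X)                                 # (batch, 1, emb_dim)
merged = tf.concat([tf.expand_dims(context_vector, 1), embedding_output], axis=-1)
print(merged.shape)  # (batch, 1, emb_dim + enc_hidden)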
Official TensorFlow Keras implementation:
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # the notebook builds this through a gru() helper; inlined here
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)

        # used for attention
        self.W1 = tf.keras.layers.Dense(self.dec_units)
        self.W2 = tf.keras.layers.Dense(self.dec_units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying tanh(FC(EO) + FC(H)) to self.V
        score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        output, state = self.gru(x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size * 1, vocab)
        x = self.fc(output)

        return x, state, attention_weights

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.dec_units))
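The training loop below also calls an encoder, which this excerpt doesn't define; a sketch of the matching Encoder from the same notebook (with the gru() helper inlined, as in the Decoder above):

class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        # x: (batch_size, max_length) token ids -> (batch_size, max_length, embedding_dim)
        x = self.embedding(x)
        # output: (batch_size, max_length, enc_units); state: (batch_size, enc_units)
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))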
Training:
for epoch in range(EPOCHS):
    start = time.time()

    hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset):
        loss = 0

        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp, hidden)
            dec_hidden = enc_hidden
            dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1)

            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

                loss += loss_function(targ[:, t], predictions)

                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)

        batch_loss = (loss / int(targ.shape[1]))
        total_loss += batch_loss

        variables = encoder.variables + decoder.variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))
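The loss_function used above is also not shown in this excerpt; in the notebook it masks out padding tokens so they don't contribute to the loss. A sketch of that masked loss (assuming token id 0 is padding; I use tf.cast here instead of the notebook's numpy-based mask):

def loss_function(real, pred):
    # ignore padding positions (token id 0) when averaging the loss
    mask = tf.cast(tf.not_equal(real, 0), dtype=tf.float32)
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=real, logits=pred) * mask
    return tf.reduce_mean(loss_)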