tensorflow 版本: 1.13.1
参考网页(https://github.com/tensorflow/nmt)中的介绍说明进行实现
tensorflow 实现如下:
class BahdanauAttention(tf.layers.Layer):
    """Additive (Bahdanau-style) attention layer for TF 1.x.

    Produces a context vector as the attention-weighted sum of the encoder
    outputs, scored against a single decoder cell state with the additive
    scoring function:

        score_t = v^T * tanh(W1 * h_enc_t + W2 * h_dec)
    """

    def __init__(self, num_units):
        super(BahdanauAttention, self).__init__()
        self.num_units = num_units
        # Projections of the encoder outputs (w1) and decoder state (w2)
        # into a shared `num_units`-dim space, plus the final projection
        # down to one scalar score per encoder time step (v).
        self.w1 = tf.layers.Dense(num_units)
        self.w2 = tf.layers.Dense(num_units)
        self.v = tf.layers.Dense(1)

    def build(self, input_shape):
        # Sub-layers were created in __init__; nothing to build here,
        # so simply mark the layer as built.
        self.built = True

    def call(self, inputs):
        """Compute the context vector for one decoding step.

        Args:
            inputs: pair of tensors
                inputs[0]: encoder outputs, shape [batch, max_time, hidden_size]
                inputs[1]: decoder cell state, shape [batch, hidden_size]

        Returns:
            Context vector of shape [batch, hidden_size].
        """
        enc_outputs = inputs[0]
        dec_state = inputs[1]
        # Add a time axis ([batch, 1, hidden_size]) so the decoder state
        # broadcasts against every encoder time step.
        dec_state = tf.expand_dims(dec_state, 1)
        # Additive energy, then scalar score per time step: [batch, max_time, 1].
        energy = tf.nn.tanh(self.w1(enc_outputs) + self.w2(dec_state))
        score = self.v(energy)
        # Normalize scores over the time axis (axis=1) into attention weights.
        weights = tf.nn.softmax(score, axis=1)
        # Weighted sum over time; weights [batch, max_time, 1] broadcast
        # against enc_outputs [batch, max_time, hidden_size].
        return tf.reduce_sum(weights * enc_outputs, axis=1)