The attention module has to be implemented as a custom Keras layer.
Briefly, the attention module combines the LSTM outputs from n time steps into a single vector, which is then fed into the next RNN.
I drew a diagram of this while working through Andrew Ng's course.
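Concretely, the layer learns one scoring vector: each time step's output is scored through tanh, the scores are softmax-normalized over time, and the outputs are summed with those weights. Here is a minimal NumPy sketch of that same computation; the batch size, sequence length, and hidden size below are made-up illustration values:

import numpy as np

# Dummy LSTM output: batch of 2 sequences, 5 time steps, 8 hidden units
h = np.random.randn(2, 5, 8)
W = np.random.randn(8, 1)                  # learned scoring vector

e = np.tanh(h @ W)[:, :, 0]                # one score per time step, shape (2, 5)
a = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)   # softmax over time
v = (h * a[:, :, None]).sum(axis=1)        # weighted sum, shape (2, 8)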
from keras import backend as K
from keras import initializers
from keras.engine.topology import Layer

class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        self.init = initializers.get('normal')
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # One scoring weight per feature of the LSTM output
        self.W = self.add_weight(name='att_weight',
                                 shape=(input_shape[-1], 1),
                                 initializer=self.init,
                                 trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def call(self, x, mask=None):
        # x: (batch, timesteps, features)
        e = K.tanh(K.squeeze(K.dot(x, self.W), axis=-1))  # one score per time step
        ai = K.exp(e)
        weights = ai / K.expand_dims(K.sum(ai, axis=1))   # softmax over time
        weighted_input = x * K.expand_dims(weights)       # reweight each time step
        return K.sum(weighted_input, axis=1)              # (batch, features)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])
# Used in exactly the same way as any other layer
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_seq)
attention = AttentionLayer()(l_lstm)
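For context, here is a minimal sketch of the layer inside a complete classification model; the vocabulary size, sequence length, embedding dimension, and class count are placeholder values chosen only for illustration:

from keras.models import Model
from keras.layers import Input, Embedding, Dense, Bidirectional, LSTM

MAX_LEN, VOCAB, EMB_DIM, N_CLASSES = 100, 20000, 128, 5   # placeholder hyperparameters

seq_input = Input(shape=(MAX_LEN,), dtype='int32')
embedded_seq = Embedding(VOCAB, EMB_DIM)(seq_input)
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_seq)
attention = AttentionLayer()(l_lstm)       # (batch, 200) after attention pooling
preds = Dense(N_CLASSES, activation='softmax')(attention)

model = Model(seq_input, preds)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])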