本文主要是利用Tensorflow中keras框架记录简单实现seq2seq+Attention模型的过程,seq2seq的应用主要有问答系统、人机对话、机器翻译等。代码中会用一个中文对话数据简单测试。
seq2seq模型介绍
seq2seq模型主要有两个部分Encoder和Decoder,Encoder负责将输入序列编码,Decoder负责解码输出序列。最简单的seq2seq模型图:
基于注意力机制的seq2seq模型。
Keras实现seq2seq+Atttention模型
本文的实现是基于Tensorflow 2.0中的keras,也可以用原始的keras也可以,如果用原始的keras,需要自己实现Attention层。
详细代码和数据:https://github.com/huanghao128/zh-nlp-demo
Encoder部分
encoder部分就是一个标准的RNN/LSTM模型,取最后时刻的隐藏层作为输出。我们用tensorflow.keras.models定义Encoder为一个sub model。
先导入tensorflow.keras的常用包。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import activations
from tensorflow.keras.layers import Layer, Input, Embedding, LSTM, Dense, Attention
from tensorflow.keras.models import Model
encoder部分结构,主要就是一个Embedding层,加上LSTM层。
class Encoder(keras.Model):
def __init__(self, vocab_size, embedding_dim, hidden_units):
super(Encoder, self).__init__()
# Embedding Layer
self.embedding = Embedding(vocab_size, embedding_dim, mask_zero=True)
# Encode LSTM Layer
self.encoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True, name="encode_lstm")
def call(self, inputs):
encoder_embed = self.embedding(inputs)
encoder_outputs, state_h, state_c = self.encoder_lstm(encoder_embed)
return encoder_outputs, state_h, state_c
Decoder部分
decoder部分结构,有三部分输入,一是encoder部分的每个时刻输出,二是encoder的隐藏状态输出,三是decoder的目标输入。另外decoder还包含一个Attention层,计算decoder每个输入与encoder的注意力。
class Decoder(keras.Model):
def __init__(self, vocab_size, embedding_dim, hidden_units):
super(Decoder, self).__init__()
# Embedding Layer
self.embedding = Embedding(vocab_size, embedding_dim, mask_zero=True)
# Decode LSTM Layer
self.decoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True, name="decode_lstm")
# Attention Layer
self.attention = Attention()
def call(self, enc_outputs, dec_inputs, states_inputs):
decoder_embed = self.embedding(dec_inputs)
dec_outputs, dec_state_h, dec_state_c = self.decoder_lstm(decoder_embed, initial_state=states_inputs)
attention_output = self.attention([dec_outputs, enc_outputs])
return attention_output, dec_state_h, dec_state_c
Encoder和Decoder合并
encoder和decoder模块合并,组成一个完整的seq2seq模型。
def Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size):
"""
seq2seq model
"""</