上篇文章我们使用tf.contrib.legacy_seq2seq下的API构建了一个简单的chatbot对话系统,但是我们已经说过,这部分代码是1.0版本之前所提供的API,将来会被弃用,而且API接口并不灵活,在实际使用过程中还会存在版本不同导致的各种个样的错误。所以我们有必要学习一下新版本的API,这里先来说一下二者的不同:
- 新版本都是用dynamic_rnn来构造RNN模型,这样就避免了数据长度不同所带来的困扰,不需要再使用model_with_buckets这种方法来构建模型,使得我们数据处理和模型代码都简洁很多
- 新版本将Attention、Decoder等几个主要的功能都分别进行封装,直接调用相应的Wapper函数进行封装即可,调用起来更加灵活方便,而且只需要写几个简单的函数既可以自定义的各个模块以满足我们个性化的需求。
- 实现了beam_search功能,可直接调用
这次我们先来看如何直接使用新版本API构造对话系统,然后等下一篇文章在分析一些主要文件和函数的源码实现。本文代码可以再我的github中找到:seq2seq_chatbot_new。欢迎fork和star~~
数据处理
仍然沿用之前的代码,不过createBatch函数可以变得简单而又整洁,原因是新版本API我们在定义输入的placeholder是不需要在定义为seq_len*batch_size这样的列表,直接定义一个batch_size*seq_len的tensor即可。所以数据处理部分的代码也可以简化为得到一个嵌套列表的形式即可。这里我们重新定义Batch类,使其包含四个元素分别为encoder_inputs、encoder_inputs_length、decoder_targets、decoder_targets_length,前两项是PAD之后的源序列及每个序列的长度,后两项为PAD之后的目的序列和每个序列的长度。这里encoder_inputs_length和decoder_targets_length是为了动态编解码时表示序列长度的作用,下面给出修改了的Batch类和createBatch函数,其他函数都没有发生变化。
class Batch:
def __init__(self):
self.encoder_inputs = [] #嵌套列表,每个元素都是一个句子中每个单词都id
self.encoder_inputs_length = [] #一维列表,每个元素对应上面每个句子的长度
self.decoder_targets = []
self.decoder_targets_length = []
def createBatch(samples):
'''
根据给出的samples(就是一个batch的数据),进行padding并构造成placeholder所需要的数据形式
:param samples: 一个batch的样本数据,列表,每个元素都是[question, answer]的形式,id
:return: 处理完之后可以直接传入feed_dict的数据格式
'''
batch = Batch()
#获取每个样本的长度,并保存在source_length和target_length中
batch.encoder_inputs_length = [len(sample[0]) for sample in samples]
batch.decoder_targets_length = [len(sample[1]) for sample in samples]
#获得一个batch样本中最大的序列长度
max_source_length = max(batch.encoder_inputs_length)
max_target_length = max(batch.decoder_targets_length)
#将每个样本进行PAD至最大长度
for sample in samples:
#将source进行反序并PAD值本batch的最大长度
source = list(reversed(sample[0]))
pad = [padToken] * (max_source_length - len(source))
batch.encoder_inputs.append(pad + source)
#将target进行PAD,并添加END符号
target = sample[1]
pad = [padToken] * (max_target_length - len(target))
batch.decoder_targets.append(target + pad)
return batch
模型构建
这一部分代码主要是从tensorflow官网给出的nmt例子的代码简化而来,实现了最基本的attention和beam_search等功能,同时有将nmt代码中繁杂的代码逻辑进行简化,将不必要的代码都清除,是的代码的易读性提高。这里参考nmt中所提到的构建train、eval、inference三个图进行模型构建,好处在于(下面部分翻译自nmt官方文档):
- inference图往往与train和eval结构存在较大差异(没有decoder输入和目标,需要使用贪婪或者beam_search进行decode,batch_size也不同等等),所以往往需要单独进行构建
- eval图也会得到简化,因为其不需要进行反向传播,只需要得到一个loss和acc值
- 数据可以分别进行feed,简化数据操作
- 变量重用变得简单,因为train、eval存在一些公用变量和代码块,就不需要我们重复定义,使代码简化
- 只需要在train时不断保存模型参数,然后在eval和infer的时候restore参数即可
以上,所以我们构建了train、eval、infer三个函数来实现上面的功能。在看代码之前我们先来简单说一下新版API几个主要的模块以及相互之间的调用关系。tf.contrib.seq2seq文件夹下面主要有下面6个文件,除了loss文件和之前的sequence_loss函数没有很大区别,这里不介绍之外,其他几个文件都会简单的说一下,这里主要介绍函数和类的功能,源码会放在下篇文章中介绍。
- decoder
- basic_decoder
- helper
- attention_wrapper
- beam_search_decoder
- loss
BasicDecoder类和dynamic_decode
decoder和basic_decoder文件可以放在一起看,decoder文件中定义了Decoder抽象类和dynamic_decode函数,dynamic_decode可以视为整个解码过程的入口,需要传入的参数就是Decoder的一个实例,他会动态的调用Decoder的step函数按步执行decode,可以理解为Decoder类定义了单步解码(根据输入求出输出,并将该输出当做下一时刻输入),而dynamic_decode则会调用control_flow_ops.while_loop这个函数来循环执行直到输出结束编码过程。而basic_decoder文件定义了一个基本的Decoder类实例BasicDecoder,看一下其初始化函数:
def __init__(self, cell, helper, initial_state, output_layer=None):
需要传入的参数就是cell类型、helper类型、初始化状态(encoder的最后一个隐层状态)、输出层(输出映射层,将rnn_size转化为vocab_size维),需要注意的就是前面两个,下面分别介绍:
cell类型(Attention类型)
cell类型就是RNNCell,也就是decode阶段的神经元,可以使简单的RNN、GRU、LSTM(也可以加上dropout、并使用MultiRNNCell进行堆叠成多层),也可以是加上了Attention功能之后的RNNcell。这就引入了attention_wrapper文件中定义的几种attention机制(BahdanauAttention、 LuongAttention、 BahdanauMonotonicAttention、 LuongMonotonicAttention)和将attention机制封装到RNNCell上面的方法AttentionWrapper。其实很简单,就跟dropoutwrapper、outputwrapper一样,我们只需要在原本RNNCell的基础上在封装一层attention即可。代码如下所示:
# 分为三步,第一步是定义attention机制,第二步是定义要是用的基础的RNNCell,第三步是使用AttentionWrapper进行封装
#定义要使用的attention机制。
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
#attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
# 定义decoder阶段要是用的LSTMCell,然后为其封装attention wrapper
decoder_cell = self._create_rnn_cell()
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism, attention_layer_size=self.rnn_size, name='Attention_Wrapper')
helper类型
helper其实就是decode阶段如何根据预测结果得到下一时刻的输入,比如训练过程中应该直接使用上一时刻的真实值作为下一时刻输入,预测过程中可以使用贪婪的方法选择概率最大的那个值作为下一时刻等等。所以Helper也就可以大致分为训练时helper和预测时helper两种。官网给出了下面几种Helper类:
- “Helper”:最基本的抽象类
- “TrainingHelper”:训练过程中最常使用的Helper,下一时刻输入就是上一时刻target的真实值</