版权声明:博主原创文章,转载请注明来源,谢谢合作!!
https://blog.csdn.net/hl791026701/article/details/84404901
这篇博文主要是介绍基于tensorflow使用google的seq2seq模型来构建一个聊天机器人,主要是学习一下encoder、decoder、attention、bean search等原理和实现方式。
seq2seq是一种很常见的技术。例如,在法语-英语翻译中,预测的当前英语单词不仅取决于所有前面的已翻译的英语单词,还取决于原始的法语输入;另一个例子,对话中当前的response不仅取决于以往的response,还取决于消息的输入。其实,seq2seq最早被用于机器翻译,后来成功扩展到多种自然语言生成任务,如文本摘要和图像标题的生成。本文将介绍几种常见的seq2seq的模型原理,seq2seq的变形以及seq2seq用到的一些小trick。
—、 seq2seq模型简介
seq2seq属于encoder-decoder结构的一种,这里看看常见的encoder-decoder结构,基本思想就是利用两个RNN,一个RNN作为encoder,另一个RNN作为decoder。encoder负责将输入序列压缩成指定长度的向量,这个向量就可以看成是这个序列的语义,这个过程称为编码,如上图,获取语义向量最简单的方式就是直接将最后一个输入的隐状态作为语义向量C。也可以对最后一个隐含状态做一个变换得到语义向量,还可以将输入序列的所有隐含状态做一个变换得到语义变量。而decoder则负责根据语义向量生成指定的序列,这个过程也称为解码,如下图,最简单的方式是将encoder得到的语义变量作为初始状态输入到decoder的rnn中,得到输出序列。可以看到上一时刻的输出会作为当前时刻的输入,而且其中语义向量C只作为初始状态参与运算,后面的运算都与语义向量C无关。
encoder-decoder模型对输入和输出序列的长度没有要求,应用场景也更加广泛。
详情可以参考:seq2seq模型详解
二、数据文本处理
- 构建模型的第一步是进行语料的获取和处理。
这次我们使用的中文电视剧对白语料 https://github.com/fateleak/dgk_lost_conv。
另外博主还搜集了其它市面上已有的开源中文聊天语料并系统化整理工作
wget https://lvzhe.oss-cn-beijing.aliyuncs.com/dgk_shooter_min.conv.zip
下载预料后要用unzip dgk_shooter_min.conv.zip进行解压。输出dgk_shooter_min.conv
我们可以看下原始语料格式
2. 我们要对语料进行简单的清洗处理,然后根据根据’“ / ”进行split得到一个个字。
for line in tqdm(fp):
if line.startswith('M '):
line = line.replace('\n','')
if '/' in line:
line = line[2:].split('/')
else:
line = list(line[2:])
line = line[:-1] #
group.append(list(regular(''.join(line))))
else:
lsat_line=None
if group:
groups.append(group)
group=[]
- 处理完之后我们要自己构造Q、A问答句。从上面语料我们可以看出每段会话由标识符“E”分割,所以我们根据(a1,a2),(a1+a2,a3) ,(a1,a2+a3)这样的组合来构造问答语句:
#假设 a1,a2,a3,三句话 (a1,a2),(a1+a2,a3) ,(a1,a2+a3)
if next_line:
x_data.append(line)
y_data.append(next_line)
if last_line and next_line:
x_data.append(last_line + make_split(last_line) + line)
y_data.append(next_line)
if next_line and next_next_line:
x_data.append(line)
y_data.append(next_line + make_split(next_line) + next_next_line)
构建好输入X、Y即输入的问答后 ,接下来我们要进行序列化处理。
ws_input = WordSequence()
ws_input.fit(x_data + y_data)
- 构建了一个word_sequence类:主要函数的作用分别是创建字典、句子转向量、词向量映射、根据超参定制化训练数据、基础数据标记、初始化词典。
4.1 每个句子特殊处理
(1)在训练过程中,每个batch中句子长度不一样,此时对于短句子用填充
(2)用于句子结尾,告诉decoder停止预测
(3)不在字典中的词用替换
(4) decoder第一个输入,告诉decoder预测开始
def fit(self,sentences,min_count=5,max_count=None,max_features=None):
"""
Args:
min_count 最小出现次数
max_count 最大出现次数
max_features 最大特征数
"""
assert not self.fited , 'WordSequence 只能 fit 一次'
count={}
for sentence in sentences:
arr=list(sentence)
for a in arr:
if a not in count:
count[a]=0
count[a]+=1
print(count)
if min_count is not None:
count={k : v for k,v in count.items() if v >= min_count}
if max_count is not None:
count={k : v for k,v in count.items() if v<=max_features}
self.word_dict = {
WordSequence.PAD_TAG:WordSequence.PAD,
WordSequence.UNK_TAG:WordSequence.UNK,
WordSequence.START_TAG:WordSequence.START,
WordSequence.END_TAG:WordSequence.END
}
if isinstance(max_features,int):