QizNLP介绍:基于tensorflow(1.x)的NLP框架,提供NLP多种任务(分类、匹配、序列标注、生成等)代码模板,包括数据处理、模型训练、部署推断的全流程,同时内置一些常见模型提供调用,并支持基于horovod的数据并行式分布训练。
Qznan/QizNLP https://github.com/Qznan/QizNLP
前言
深度学习中的单轮闲聊机器人(single-turn chitchat-bot),通常采用与机器翻译相同的处理范式,即序列到序列模型(seq2seq)。本文将介绍在QizNLP框架中如何利用经典的Transformer模型训练一个闲聊机器人。(想看效果可直接翻到文末截图)
训练数据
训练数据采用清华组发布的LCCC闲聊语料(github地址),该语料包括base&large版本,其中base版本使用了更严格的过滤规则,所以更干净,故本次实验采用base版本。
训练过程
利用QizNLP中的seq2seq模型训练代码模板来进行基于Transformer的闲聊机器人训练。
-
步骤一:安装QizNLP:
pip install QizNLP
-
步骤二:在新建项目地址中初始化:
mkdir ~/myproject && cd ~/myproject
qiznlp_init
# 执行完毕后会看到myproject中已生成run、model、data等目录
-
步骤三:下载LCCC-base-split.zip语料解压后放入/data目录下:
└─data
└─LCCC-base-split
├─LCCC-base_test.json
├─LCCC-base_train.json
└─LCCC-base_valid.json
-
步骤四:修改run/run_s2s.py中的相关代码,包括:
-
设置训练参数,如训练轮数设为4:
‘‘‘ line 23~29 ’’’
conf = utils.dict2obj({
'early_stop_patience': None,
'just_save_best': False,
'n_epochs': 4,
'data_type': 'tfrecord',
# 'data_type': 'pkldata',
})
2. 选择字典文件名(通过仅保留LCCC项而注释其它,下同):
’’’ line 43~46 ’’’
self.token2id_dct = {
# 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/s2s_char2id.dct', use_line_no=True), # 自有数据
# 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/XHJ_s2s_char2id.dct', use_line_no=True), # 小黄鸡
'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/LCCC_s2s_char2id.dct', use_line_no=True), # LCCC
}
3. 选择模型(Transformer):
’’’ line 359~360 ’’’
rm_s2s = Run_Model_S2S('trans') # use transformer seq2seq
# rm_s2s = Run_Model_S2S('rnn_s2s') # use biGRU encoder + bah_attn + GRU decoder
4. 选择训练函数和指定batch_size:
’’’ line 362~369 ’’’
# 训练自有数据
# rm_s2s.train('s2s_ckpt_1', '../data/s2s_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512) # train
# 训练小黄鸡
# rm_s2s.train('s2s_ckpt_XHJ1', '', preprocess_raw_data=preprocess_common_dataset_XiaoHJ, batch_size=512) # train
# 训练LCCC语料
rm_s2s.train('s2s_ckpt_LCCC1', '', preprocess_raw_data=preprocess_common_dataset_LCCC, batch_size=512) # train
5. 选择在训练完要测试模型时载入的ckpt名:
’’’ line 373 ’’’
rm_s2s.restore('s2s_ckpt_LCCC1') # for infer
-
步骤五:执行python run_s2s.py,开始进行训练。并且训练结束后会自动载入模型进行推断(输入问句,闲聊机器人给出回复):
训练截图:
此图是后面补的(重新运行训练),所以tfrecord等文件都已存在,另外请忽略速度(因为这是在cpu..)
训练4个epo后模型loss是:训练集3.32/测试集3.41(目测继续训练还能再降)。载入该模型进行推断。(推断时已集成了简单的回复后处理排序模块)
推断例子:
输入“你好”,模型原始输出(分数越大即绝对值越小越好):
输入“你好”,模型原始输出
输入“你好”,模型后处理排序后输出 (bad respond的分数设为极小值-1e7) :
输入“你好”,模型后处理排序后输出
输入“最近工作压力好大啊”,模型原始输出:
输入“最近工作压力好大啊”,模型原始输出
输入“最近工作压力好大啊”,模型后处理排序后输出:
输入“最近工作压力好大啊”,模型后处理排序后输出
-
步骤六:根据回复效果调整模型解码时的相关参数(model/s2s_model.py中):
# line 15~31
conf = utils.dict2obj({
'vocab_size': 4500,
'embed_size': 300,
'hidden_size': 300,
'num_heads': 6,
'num_encoder_layers': 6,
'num_decoder_layers': 6,
'dropout_rate': 0.2,
'lr': 1e-3,
'pretrain_emb': None,
# 以上是模型参数,以下是解码时相关参数
'beam_size': 40,
'max_decode_len': 50,
'eos_id': 2, # 句子结束符对应词典id(默认为2)
'gamma': 1, # 多样性鼓励因子
'num_group': 1, # 分组beam_search
'top_k': 30 # 分组beam_search首字符采样范围
})
(步骤五中的回复示例的参数设置即如上图所示)
主要参数说明
-
beam_size:beam_search时束的大小。越大则回复的多样性越好,但计算量及响应时间会增加。设为1 即使用greedy search
-
max_decode_len:最大解码长度
-
gamma:多样性鼓励因子,越大多样性越好。一般设为1。参考此论文
-
num_group:beam_search时设置分组数量,每组共享同一个开头字符,组间不同开头字符,以保持多样性。设为1即不分组
-
top_k:分组时每组开头字符从前k个字符中采样。要求top_k >= num_group
这里尝试展示一下设置num_group=10的效果,此时top_k=30会生效。仍以输入“最近工作压力好大啊”为例,模型原始输出:
设置分组后,输入“最近工作压力好大啊”,模型原始输出
输入“最近工作压力好大啊”,模型后处理排序后输出:
设置分组后,输入“最近工作压力好大啊”,模型后处理排序后输出
在最终实践应用时,可根据情况截取后处理排序后的前k个作为候选回复,并从中随机采样1个作为最终回复。
模型部署
部署可参考deploy/下的example.py和web_api.py文件。前者有如何载入已训练好的模型(ckpt或pbmodel)的示例,后者有利用tornado搭建模型webAPI服务的示例。具体可参考QizNLP项目所在github的说明^_^
结语
以上就是使用QizNLP框架来快速工程化实现一个基于Transformer的闲聊机器人流程。由于LCCC是中文领域中难得的高质量对话语料,尽管只进行了简单的训练,但机器人回复效果还不错。当然这也离不开在beam_search解码过程中使用的一些优化手段:如多样性鼓励、分组采样等。之后会围绕该框架继续介绍更多NLP实践,欢迎各位大佬star~~
参考
LCCC语料:A Large-Scale Chinese Short-Text Conversation Dataset.(github.com/thu-coai/CDi)
QizNLP:(github.com/Qznan/QizNLP)