机器人

用的模型为RNN(seq2seq),和前文的《RNN生成古诗词》《RNN生成音乐》类似。

       本次博客使用的数据集:影视对白数据集

       下载数据集后,解压提取dgk_shooter_min.conv文件;

         1)数据预处理:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #coding=utf-8  
  2. import os  
  3. import random  
  4. from io import open  
  5. conv_path = 'dgk_shooter_min.conv'  
  6. #判断数据集是否存在?  
  7. if not os.path.exists(conv_path):  
  8.     print('数据集不存在')  
  9.     exit()  
  10.   
  11. # 数据集格式  
  12. """ 
  13. E 
  14. M 畹/华/吾/侄/ 
  15. M 你/接/到/这/封/信/的/时/候/ 
  16. M 不/知/道/大/伯/还/在/不/在/人/世/了/ 
  17. E 
  18. M 咱/们/梅/家/从/你/爷/爷/起/ 
  19. M 就/一/直/小/心/翼/翼/地/唱/戏/ 
  20. M 侍/奉/宫/廷/侍/奉/百/姓/ 
  21. M 从/来/不/曾/遭/此/大/祸/ 
  22. M 太/后/的/万/寿/节/谁/敢/不/穿/红/ 
  23. M 就/你/胆/儿/大/ 
  24. M 唉/这/我/舅/母/出/殡/ 
  25. M 我/不/敢/穿/红/啊/ 
  26. M 唉/呦/唉/呦/爷/ 
  27. M 您/打/得/好/我/该/打/ 
  28. M 就/因/为/没/穿/红/让/人/赏/咱/一/纸/枷/锁/ 
  29. M 爷/您/别/给/我/戴/这/纸/枷/锁/呀/ 
  30. E 
  31. M 您/多/打/我/几/下/不/就/得/了/吗/ 
  32. M 走/ 
  33. M 这/是/哪/一/出/啊/…/ / /这/是/ 
  34. M 撕/破/一/点/就/弄/死/你/ 
  35. M 唉/ 
  36. M 记/着/唱/戏/的/再/红/ 
  37. M 还/是/让/人/瞧/不/起/ 
  38. M 大/伯/不/想/让/你/挨/了/打/ 
  39. M 还/得/跟/人/家/说/打/得/好/ 
  40. M 大/伯/不/想/让/你/再/戴/上/那/纸/枷/锁/ 
  41. M 畹/华/开/开/门/哪/ 
  42. E 
  43. ... 
  44. """  
  45.   
  46. # 我首先使用文本编辑器sublime把dgk_shooter_min.conv文件编码转为UTF-8,一下子省了不少麻烦  
  47. convs = []  # 对话集合  
  48. with open(conv_path, encoding="utf8") as f:  
  49.     one_conv = []  # 一次完整对话  
  50.     for line in f:  
  51.         line = line.strip('\n').replace('/''')#将分隔符去掉  
  52.         if line == '':  
  53.            continue  
  54.         if line[0] == 'E':  
  55.            if one_conv:  
  56.               convs.append(one_conv)  
  57.            one_conv = []  
  58.         elif line[0] == 'M':  
  59.            one_conv.append(line.split(' ')[1])  
  60. #将对话转成utf-8格式,并将其保存在dgk_shooter_min.conv文件中  
  61.   
  62. """ 
  63. print(convs[:3])  # 个人感觉对白数据集有点不给力啊 
  64. [ ['畹华吾侄', '你接到这封信的时候', '不知道大伯还在不在人世了'], 
  65.   ['咱们梅家从你爷爷起', '就一直小心翼翼地唱戏', '侍奉宫廷侍奉百姓', '从来不曾遭此大祸', '太后的万寿节谁敢不穿红', '就你胆儿大', '唉这我舅母出殡', '我不敢穿红啊', '唉呦唉呦爷', '您打得好我该打', '就因为没穿红让人赏咱一纸枷锁', '爷您别给我戴这纸枷锁呀'], 
  66.   ['您多打我几下不就得了吗', '走', '这是哪一出啊 ', '撕破一点就弄死你', '唉', '记着唱戏的再红', '还是让人瞧不起', '大伯不想让你挨了打', '还得跟人家说打得好', '大伯不想让你再戴上那纸枷锁', '畹华开开门哪'], ....] 
  67. """  
  68.   
  69. # 把对话分成问与答  
  70. ask = []        # 问  
  71. response = []   # 答  
  72. for conv in convs:  
  73.   if len(conv) == 1:  
  74.      continue  
  75.   if len(conv) % 2 != 0:  # 奇数对话数, 转为偶数对话  
  76.      conv = conv[:-1]  
  77.   for i in range(len(conv)):  
  78.      if i % 2 == 0:  
  79.         ask.append(conv[i])#偶数对,填写问题  
  80.      else:  
  81.         response.append(conv[i])#回答  
  82.   
  83. """ 
  84. print(len(ask), len(response)) 
  85. print(ask[:3]) 
  86. print(response[:3]) 
  87. ['畹华吾侄', '咱们梅家从你爷爷起', '侍奉宫廷侍奉百姓'] 
  88. ['你接到这封信的时候', '就一直小心翼翼地唱戏', '从来不曾遭此大祸'] 
  89. """  
  90.   
  91.   
  92. def convert_seq2seq_files(questions, answers, TESTSET_SIZE=8000):  
  93.     # 创建文件  
  94.     train_enc = open('train.enc''w')  # 问  
  95.     train_dec = open('train.dec''w')  # 答  
  96.     test_enc = open('test.enc''w')  # 问  
  97.     test_dec = open('test.dec''w')  # 答  
  98.   
  99.     # 选择8000数据作为测试数据  
  100.     test_index = random.sample([i for i in range(len(questions))], TESTSET_SIZE)  
  101.   
  102.     for i in range(len(questions)):  
  103.         if i in test_index:#创建测试文件  
  104.             test_enc.write(questions[i] + '\n')  
  105.             test_dec.write(answers[i] + '\n')  
  106.         else:#创建训练文件  
  107.             train_enc.write(questions[i] + '\n')  
  108.             train_dec.write(answers[i] + '\n')  
  109.         if i % 1000 == 0:#表示处理了多少个i  
  110.             print(len(range(len(questions))), '处理进度:', i)  
  111.   
  112.     train_enc.close()  
  113.     train_dec.close()  
  114.     test_enc.close()  
  115.     test_dec.close()  
  116.   
  117.   
  118. convert_seq2seq_files(ask, response)  
  119. # 生成的*.enc文件保存了问题  
  120. # 生成的*.dec文件保存了回答  
           2)创建词汇表

  

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #coding=utf-8  
  2. # 前一步生成的问答文件路径  
  3. train_encode_file = 'train.enc'  
  4. train_decode_file = 'train.dec'  
  5. test_encode_file = 'test.enc'  
  6. test_decode_file = 'test.dec'  
  7.   
  8. print('开始创建词汇表...')  
  9. # 特殊标记,用来填充标记对话  
  10. PAD = "__PAD__"  
  11. GO = "__GO__"  
  12. EOS = "__EOS__"  # 对话结束  
  13. UNK = "__UNK__"  # 标记未出现在词汇表中的字符  
  14. START_VOCABULART = [PAD, GO, EOS, UNK]  
  15. PAD_ID = 0  
  16. GO_ID = 1  
  17. EOS_ID = 2  
  18. UNK_ID = 3  
  19. # 参看tensorflow.models.rnn.translate.data_utils  
  20.   
  21. vocabulary_size = 5000  
  22.   
  23.   
  24. # 生成词汇表文件  
  25. def gen_vocabulary_file(input_file, output_file):  
  26.     vocabulary = {}  
  27.     with open(input_file) as f:  
  28.         counter = 0  
  29.         for line in f:  
  30.            counter += 1  
  31.            tokens = [word for word in line.strip()]  
  32.            for word in tokens:  
  33.                if word in vocabulary:  
  34.                   vocabulary[word] += 1  
  35.                else:  
  36.                   vocabulary[word] = 1  
  37.         vocabulary_list = START_VOCABULART + sorted(vocabulary, key=vocabulary.get, reverse=True)  
  38.         # 取前5000个常用汉字, 应该差不多够用了(额, 好多无用字符, 最好整理一下. 我就不整理了)  
  39.         if len(vocabulary_list) > 5000:  
  40.            vocabulary_list = vocabulary_list[:5000]  
  41.         print(input_file + " 词汇表大小:", len(vocabulary_list))  
  42.         with open(output_file, "w") as ff:  
  43.            for word in vocabulary_list:  
  44.                ff.write(word + "\n")  
  45.   
  46.   
  47. gen_vocabulary_file(train_encode_file, "train_encode_vocabulary")  
  48. gen_vocabulary_file(train_decode_file, "train_decode_vocabulary")  
  49.   
  50. train_encode_vocabulary_file = 'train_encode_vocabulary'  
  51. train_decode_vocabulary_file = 'train_decode_vocabulary'  
  52.   
  53. print("对话转向量...")  
  54.   
  55.   
  56. # 把对话字符串转为向量形式  
  57. def convert_to_vector(input_file, vocabulary_file, output_file):  
  58.     tmp_vocab = []  
  59.     with open(vocabulary_file, "r") as f:  
  60.         tmp_vocab.extend(f.readlines())  
  61.     tmp_vocab = [line.strip() for line in tmp_vocab]  
  62.     vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])  
  63.     # {'硕': 3142, 'v': 577, 'I': 4789, '\ue796': 4515, '拖': 1333, '疤': 2201 ...}  
  64.     output_f = open(output_file, 'w')  
  65.     with open(input_file, 'r') as f:  
  66.         for line in f:  
  67.             line_vec = []  
  68.             for words in line.strip():  
  69.                 line_vec.append(vocab.get(words, UNK_ID))  
  70.             output_f.write(" ".join([str(num) for num in line_vec]) + "\n")  
  71.     output_f.close()  
  72.   
  73.   
  74. convert_to_vector(train_encode_file, train_encode_vocabulary_file, 'train_encode.vec')  
  75. convert_to_vector(train_decode_file, train_decode_vocabulary_file, 'train_decode.vec')  
  76.   
  77. convert_to_vector(test_encode_file, train_encode_vocabulary_file, 'test_encode.vec')  
  78. convert_to_vector(test_decode_file, train_decode_vocabulary_file, 'test_decode.vec')  
生成的train_encode.vec和train_decode.vec用于训练,对应的词汇表是train_encode_vocabulary和train_decode_vocabulary。

        3)训练

[html]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #coding=utf-8  
  2. import tensorflow as tf  # 0.12  
  3. from tensorflow.models.rnn.translate import seq2seq_model  
  4. import os  
  5. import numpy as np  
  6. import math  
  7. #导入文件  
  8. PAD_ID = 0  
  9. GO_ID = 1  
  10. EOS_ID = 2  
  11. UNK_ID = 3  
  12.   
  13. train_encode_vec = 'train_encode.vec'  
  14. train_decode_vec = 'train_decode.vec'  
  15. test_encode_vec = 'test_encode.vec'  
  16. test_decode_vec = 'test_decode.vec'  
  17.   
  18. # 词汇表大小5000  
  19. vocabulary_encode_size = 5000  
  20. vocabulary_decode_size = 5000  
  21.   
  22. buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]  
  23. layer_size = 256  # 每层大小  
  24. num_layers = 3  # 层数  
  25. batch_size = 64  
  26.   
  27.   
  28. # 读取*dencode.vec和*decode.vec数据(数据还不算太多, 一次读人到内存)  
  29. def read_data(source_path, target_path, max_size=None):  
  30.     data_set = [[] for _ in buckets]#生成了[[],[],[],[]],即当值与参数不一样  
  31.     with tf.gfile.GFile(source_path, mode="r") as source_file:#以读格式打开源文件(source_file)  
  32.         with tf.gfile.GFile(target_path, mode="r") as target_file:#以读格式打开目标文件  
  33.             source, target = source_file.readline(), target_file.readline()#只读取一行  
  34.             counter = 0#计数器为0  
  35.             while source and target and ( not max_size or counter < max_size):#当读入的还存在时  
  36.                 counter += 1  
  37.                 source_ids = [int(x) for x in source.split()]#source的目标序列号,默认分隔符为空格,组成了一个源序列  
  38.                 target_ids = [int(x) for x in target.split()]#target组成一个目标序列,为目标序列  
  39.                 target_ids.append(EOS_ID)#加上结束标记的序列号  
  40.                 for bucket_id, (source_size, target_size) in enumerate(buckets):#enumerate()遍历序列中的元素和其下标  
  41.                     if len(source_ids) < source_size and len(target_ids) < target_size:#判断是否超越了最大长度  
  42.                         data_set[bucket_id].append([source_ids, target_ids])#读取到数据集文件中区  
  43.                         break#一次即可,跳出当前循环  
  44.                 source, target = source_file.readline(), target_file.readline()#读取了下一行  
  45.     return data_set  
  46.   
  47.   
  48. model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_sizetarget_vocab_size=vocabulary_decode_size,  
  49.                                    buckets=bucketssize=layer_sizenum_layers=num_layersmax_gradient_norm=5.0,  
  50.                                    batch_size=batch_sizelearning_rate=0.5, learning_rate_decay_factor=0.97,  
  51.                                    forward_only=False)  
  52.   
  53. config = tf.ConfigProto()  
  54. config.gpu_options.allocator_type = 'BFC'  # 防止 out of memory  
  55.   
  56. with tf.Session(config=config) as sess:  
  57.     # 恢复前一次训练  
  58.     ckpt = tf.train.get_checkpoint_state('.')  
  59.     if ckpt != None:  
  60.         print(ckpt.model_checkpoint_path)  
  61.         model.saver.restore(sess, ckpt.model_checkpoint_path)  
  62.     else:  
  63.         sess.run(tf.global_variables_initializer())  
  64.   
  65.     train_set = read_data(train_encode_vec, train_decode_vec)  
  66.     test_set = read_data(test_encode_vec, test_decode_vec)  
  67.   
  68.     train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]#分别计算出训练集中的长度【1,2,3,4】  
  69.     train_total_size = float(sum(train_bucket_sizes))#训练实例总数  
  70.     train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]#计算了之前所有的数的首战百分比  
  71.   
  72.     loss = 0.0#损失置位0  
  73.     total_step = 0  
  74.     previous_losses = []  
  75.     # 一直训练,每过一段时间保存一次模型  
  76.     while True:  
  77.         random_number_01 = np.random.random_sample()#每一次循环结果不一样  
  78.         #选出最小的大于随机采样的值的索引号  
  79.         bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])  
  80.   
  81.         encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)  
  82.         #get_batch()函数首先获取bucket的encoder_size与decoder_size  
  83.         _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)#损失  
  84.   
  85.         loss += step_loss / 500  
  86.         total_step += 1  
  87.   
  88.         print(total_step)  
  89.         if total_step % 500 == 0:  
  90.             print(model.global_step.eval(), model.learning_rate.eval(), loss)  
  91.   
  92.     # 如果模型没有得到提升,减小learning rate  
  93.             if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):#即损失比以前的大则降低学习率  
  94.                 sess.run(model.learning_rate_decay_op)  
  95.             previous_losses.append(loss)  
  96.     # 保存模型  
  97.             checkpoint_path = "chatbot_seq2seq.ckpt"  
  98.             model.saver.save(sess, checkpoint_path, global_step=model.global_step)  
  99.             #返回路径checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))  
  100.             loss = 0.0#置当前损失为0  
  101.     # 使用测试数据评估模型  
  102.             for bucket_id in range(len(buckets)):  
  103.                 if len(test_set[bucket_id]) == 0:  
  104.                     continue  
  105.                 #获取当前bucket的encoder_inputs, decoder_inputs, target_weights  
  106.                 encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)  
  107.                 #计算bucket_id的损失权重  
  108.                 _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)  
  109.                 eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')  
  110.                 print(bucket_id, eval_ppx)#输出的是bucket_id与eval_ppx  

   这个阶段最好要用GPU运行,不然需要很长时间。

          4)使用训练好的模型

                 

[html]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. <span style="font-size:14px;">#coding=utf-8  
  2. import tensorflow as tf  # 0.12  
  3. from tensorflow.models.rnn.translate import seq2seq_model#在翻译模型中,引入seq2seq_model  
  4. import os  
  5. import numpy as np  
  6.   
  7. PAD_ID = 0  
  8. GO_ID = 1  
  9. EOS_ID = 2  
  10. UNK_ID = 3  
  11. #词汇表路径path  
  12. train_encode_vocabulary = 'train_encode_vocabulary'  
  13. train_decode_vocabulary = 'train_decode_vocabulary'  
  14.   
  15. #读取词汇表  
  16. def read_vocabulary(input_file):  
  17.     tmp_vocab = []  
  18.     with open(input_file, "r") as f:  
  19.         tmp_vocab.extend(f.readlines())#打开的文件全部读入input_file中  
  20.     tmp_vocab = [line.strip() for line in tmp_vocab]#转换成列表  
  21.     vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])  
  22.     return vocab, tmp_vocab#返回字典,列表  
  23.   
  24.   
  25. vocab_en, _, = read_vocabulary(train_encode_vocabulary)#得到词汇字典  
  26. _, vocab_de, = read_vocabulary(train_decode_vocabulary)#得到词汇列表  
  27.   
  28. # 词汇表大小5000  
  29. vocabulary_encode_size = 5000  
  30. vocabulary_decode_size = 5000  
  31.   
  32. buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]  
  33. layer_size = 256  # 每层大小  
  34. num_layers = 3  # 层数  
  35. batch_size = 1  
  36.   
  37. model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_sizetarget_vocab_size=vocabulary_decode_size,  
  38.                                    buckets=bucketssize=layer_sizenum_layers=num_layersmax_gradient_norm=5.0,  
  39.                                    batch_size=batch_sizelearning_rate=0.5, learning_rate_decay_factor=0.99,  
  40.                                    forward_only=True)  
  41. #模型说明:源,目标词汇尺寸=vocabulary_encode(decode)_size;batch_size:训练期间使用的批次的大小;#forward_only:仅前向不传递误差  
  42.   
  43. model.batch_size = 1#batch_size=1  
  44.   
  45. with tf.Session() as sess:#打开作为一次会话  
  46.     # 恢复前一次训练  
  47.     ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt)  
  48.     #如果ckpt存在,输出模型路径  
  49.     if ckpt != None:  
  50.         print(ckpt.model_checkpoint_path)  
  51.         model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数  
  52.     else:  
  53.         print("没找到模型")  
  54.     #测试该模型的能力  
  55.     while True:  
  56.         input_string = input('me > ')  
  57.     # 退出  
  58.         if input_string == 'quit':  
  59.            exit()  
  60.   
  61.         input_string_vec = []#输入字符串向量化  
  62.         for words in input_string.strip():  
  63.             input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_ID  
  64.         bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的id  
  65.         encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)  
  66.         #get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights  
  67.         _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)  
  68.         #得到其输出  
  69.         outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表  
  70.         if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End]  
  71.             outputs = outputs[:outputs.index(EOS_ID)]  
  72.   
  73.         response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中  
  74.         print('AI > ' + response)#输出回复</span>  
结果:

转载地址: http://blog.topspeedsnail.com/archives/10735

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值