Natural Language Processing: seq2seq with Buckets (Part 4)

Earlier parts in this series:
Natural Language Processing: Word Segmentation and Keyword Extraction (Part 1)
Natural Language Processing: Crawling Keyword Data with scrapy (Part 2)
Natural Language Processing: Building a Vocabulary from Q&A Corpora and Converting Words to Vectors (Part 3)

1 Processing the Dataset

  • Demo
import tensorflow as tf

EOS_ID = 2  # end-of-answer token ID (the full ID table is listed in section 2.1)

# Word-ID (vector) files produced in Part 3
train_encode_vec = './data/word2vec/train_question_encode.vec'
train_decode_vec = './data/word2vec/train_answer_decode.vec'

# Bucket structure of the corpus
# (5, 10): question bucket length 5, answer bucket length 10
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
def read_data(question_path, answer_path, max_size=None):
	# One list per bucket
	data_set = [[] for _ in buckets]
	with tf.gfile.GFile(question_path, mode="r") as question_file:
		with tf.gfile.GFile(answer_path, mode="r") as answer_file:
			question, answer = question_file.readline(), answer_file.readline()
			counter = 0
			while question and answer and (not max_size or counter < max_size):
				counter += 1
				question_ids = [int(x) for x in question.split()]
				answer_ids = [int(x) for x in answer.split()]
				answer_ids.append(EOS_ID)
				# Store the pair in the first bucket whose lengths it fits
				for bucket_id, (question_size, answer_size) in enumerate(buckets):
					if len(question_ids) < question_size and len(answer_ids) < answer_size:
						data_set[bucket_id].append([question_ids, answer_ids])
						break
				question, answer = question_file.readline(), answer_file.readline()
	return data_set
# Word-ID data set, grouped by bucket
train_set = read_data(train_encode_vec, train_decode_vec)
print("Train set: {}".format(train_set))
# Number of dialogues per bucket; one question plus one answer is one complete dialogue
train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
print("train bucket sizes: {}".format(train_bucket_sizes))
# Total number of dialogues
total_size = float(sum(train_bucket_sizes))
print("total size: {}".format(total_size))
for i in train_set:
	print("data: {}".format(i))
print("length of data set: {}".format(len(train_set)))
# Cumulative proportion of dialogues per bucket
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / total_size for i in range(len(train_bucket_sizes))]
print("train buckets scale: {}".format(train_buckets_scale))
  • Result

[[[[7, 5], [36, 8, 29, 12, 2]], [[24, 8, 9], [4, 37, 81, 113, 114, 198, 4, 2]], [[154, 155], [30, 30, 4, 38, 82, 62, 2]], [[44, 45, 5, 12], [203, 204, 205, 206, 31, 116, 117, 2]], [[44, 45, 5, 12], [8, 214, 215, 7, 63, 119, 15, 2]], [[44, 45, 5, 12], [216, 217, 2]], [[27, 4, 28, 69], [4, 121, 122, 2]], [[162, 98, 163], [236, 237, 15, 2]], [[4], [2]], [[11, 11, 11], [243, 244, 4, 2]], [[13, 20], [249, 14, 2]], [[13, 102], [258, 259, 17, 8, 31, 260, 2]], [[5, 51], [72, 84, 95, 2]], [[27, 4, 7, 71], [4, 14, 263, 264, 32, 2]], [[28, 48, 72], [14, 21, 6, 15, 2]], [[182, 29, 6], [4, 37, 92, 54, 7, 2]], [[6, 7, 28, 183], [28, 281, 282, 2]], [[4, 5, 106], [4, 8, 101, 2]], [[4, 55, 5], [4, 9, 8, 157, 6, 71, 8, 2]], [[109, 110, 111], [20, 309, 74, 2]], [[19, 191, 14], [6, 9, 18, 4, 2]], [[76, 5, 12], [95, 107, 8, 6, 63, 119, 2]], [[14, 14], [16, 11, 2]], [[11, 11, 11], [26, 6, 33, 15, 2]], [[194, 195, 13, 23], [5, 5, 2]], [[79, 196], [5, 5, 2]], [[], [5, 5, 2]], [[113, 113], [314, 2]], [[40, 4, 23], [6, 20, 87, 2]], [[11, 11], [26, 6, 33, 2]], [[58], [56, 2]], [[6], [6, 29, 12, 2]], [[117, 210, 58], [5, 5, 2]], [[212], [5, 5, 2]], [[119, 119], [5, 5, 2]], [[], [5, 5, 2]], [[213], [5, 5, 2]], [[], [5, 5, 2]], [[], [5, 5, 2]], [[], [5, 5, 2]], [[214, 215], [20, 9, 10, 17, 335, 336, 337, 2]], [[17, 18, 216, 4], [5, 5, 2]], [[217], [338, 2]], [[218], [339, 340, 341, 342, 2]], [[219], [164, 164, 2]], [[41], [343, 2]], [[120], [88, 2]], [[224], [167, 167, 2]], [[41], [168, 2]], [[226], [352, 2]], [[122], [353, 2]], [[], [5, 5, 2]], [[32, 32, 32], [5, 5, 2]], [[22, 8, 77, 78], [36, 29, 12, 13, 360, 153, 2]], [[22, 8], [361, 362, 363, 364, 176, 2]], [[16, 77, 78], [177, 78, 104, 77, 9, 178, 2]], [[19, 79, 10], [74, 74, 34, 7, 2]], [[87, 4], [5, 5, 2]], [[4, 23], [4, 23, 13, 33, 4, 13, 180, 180, 2]], [[4, 129, 129], [378, 182, 379, 2]], [[244], [16, 181, 13, 385, 165, 386, 2]], [[245, 246], [387, 6, 17, 183, 38, 9, 27, 134, 2]], [[31], [404, 141, 2]], [[11, 11, 11, 11], [26, 26, 6, 33, 2]], [[4, 62, 62, 6], [98, 67, 8, 4, 111, 157, 2]], [[55, 56, 16], [13, 2]], [[16], [23, 2]], [[56], [23, 23, 2]], [[4, 23], [2]], [[4, 140, 141, 8], [4, 123, 427, 428, 429, 2]], [[62, 62], [111, 111, 2]], [[4, 59, 264, 9], [17, 432, 2]], [[66, 267], [438, 193, 2]], [[268, 98, 269], [439, 440, 441, 2]], [[84, 84], [443, 190, 2]], [[137, 22, 8], [45, 99, 26, 26, 65, 2]], [[143, 37], [5, 5, 2]], [[13, 20], [177, 78, 104, 77, 9, 178, 2]], [[31, 90], [195, 195, 2]], [[15, 290, 291], [9, 4, 8, 100, 55, 7, 2]], [[147, 31, 148, 149], [69, 12, 10, 2]], [[24, 8, 144], [9, 43, 44, 2]], [[147], [468, 109, 14, 21, 6, 2]], [[31, 10, 8], [30, 8, 7, 2]], [[91, 4, 23], [4, 23, 33, 2]], [[309], [69, 12, 10, 55, 469, 9, 27, 2]], [[35, 57, 83, 94], [27, 46, 27, 46, 2]]], [[[4, 95, 66, 12, 8], [199, 115, 9, 8, 4, 8, 200, 201, 202, 2]], [[30, 8, 156], [207, 208, 209, 118, 38, 210, 9, 17, 83, 7, 211, 15, 2]], [[44, 45, 5, 12], [8, 4, 212, 213, 84, 18, 7, 16, 11, 7, 63, 85, 15, 2]], [[13, 96, 157, 4, 16, 67, 34, 26, 8], [63, 218, 8, 64, 7, 2]], [[35, 4, 19, 36, 10], [4, 86, 224, 6, 20, 87, 7, 2]], [[4, 28, 97, 6, 25, 159], [6, 225, 67, 226, 227, 228, 2]], [[4, 5, 160, 161, 14], [30, 30, 2]], [[99, 46, 5, 12, 10, 47, 164], [8, 4, 20, 238, 49, 239, 240, 7, 68, 241, 2]], [[4, 100, 7, 100], [6, 50, 4, 24, 50, 31, 83, 50, 125, 4, 126, 242, 2]], [[46, 165, 166, 167, 168], [245, 246, 247, 129, 248, 2]], [[28, 169, 170, 4, 28, 48, 8], [14, 90, 130, 131, 48, 132, 70, 90, 21, 133, 133, 2]], [[4, 10, 180, 181, 37, 5, 12], [277, 148, 99, 7, 278, 279, 280, 2]], 
[[4, 103, 27, 4, 184, 6], [4, 23, 13, 283, 284, 2]], [[15, 185, 186], [9, 19, 81, 113, 4, 286, 6, 54, 74, 96, 2]], [[4, 5, 105, 14, 51], [85, 101, 287, 7, 288, 289, 290, 150, 291, 2]], [[4, 188, 189, 53, 9], [299, 300, 7, 6, 301, 105, 152, 10, 2]], [[4, 38, 10, 76, 5, 12], [302, 303, 304, 2]], [[109, 110, 111, 5, 69, 190], [17, 39, 37, 308, 7, 64, 11, 2]], [[192, 63, 7, 193, 8], [5, 5, 2]], [[4, 16, 68, 34, 26, 56], [6, 158, 52, 32, 2]], [[15, 17, 40, 4, 19, 36], [6, 91, 322, 323, 2]], [[80, 29, 4], [11, 34, 162, 58, 24, 330, 6, 331, 332, 11, 34, 2]], [[11, 11], [51, 45, 73, 10, 163, 333, 334, 16, 11, 2]], [[206, 81, 13, 20, 9, 116, 207, 10], [5, 5, 2]], [[83], [16, 11, 13, 29, 12, 165, 344, 31, 108, 14, 345, 2]], [[82, 7, 31, 43, 9], [4, 76, 7, 8, 41, 346, 2]], [[225], [156, 349, 350, 138, 76, 351, 139, 38, 23, 13, 34, 2]], [[59, 7, 50, 4], [49, 49, 4, 42, 10, 172, 173, 4, 47, 2]], [[4, 53, 21, 228, 22, 8], [365, 366, 367, 2]], [[6, 49, 18, 229, 26], [5, 5, 2]], [[6, 49, 18, 67, 34, 26], [142, 7, 27, 27, 176, 368, 86, 2]], [[87, 59, 22, 8, 235, 236, 10, 128, 128], [5, 5, 2]], [[6, 237, 58, 15, 15], [372, 179, 373, 179, 374, 2]], [[243, 33, 6, 64, 130, 88, 8], [61, 7, 2]], [[4, 10, 130, 88, 5, 12], [380, 381, 382, 383, 78, 384, 2]], [[15, 42, 27, 4, 247], [30, 2]], [[4, 132, 61, 24, 8, 30, 8, 248], [6, 28, 25, 71, 392, 89, 15, 2]], [[4, 10, 132, 28, 69], [14, 36, 184, 109, 393, 75, 7, 394, 395, 2]], [[13, 96, 4, 5, 106, 80, 5, 14], [4, 8, 185, 101, 97, 2]], [[133, 14, 52, 5, 22, 8], [79, 79, 2]], [[134, 134, 249, 65, 33], [79, 79, 7, 186, 82, 85, 2]], [[251, 88, 7, 252, 253], [16, 11, 2]], [[135, 4], [58, 187, 4, 46, 187, 175, 10, 24, 56, 6, 130, 2]], [[254, 73, 9, 135, 50, 4], [49, 49, 4, 42, 10, 172, 173, 4, 47, 2]], [[7, 123, 124], [102, 102, 6, 7, 55, 47, 6, 43, 44, 128, 403, 7, 2]], [[4, 23, 4, 255, 79, 63, 5, 4, 23], [4, 8, 405, 406, 20, 407, 2]], [[4, 5, 14, 52], [9, 163, 18, 7, 6, 8, 4, 25, 408, 22, 2]], [[6, 18, 256, 4, 257, 133], [16, 11, 4, 42, 10, 188, 188, 2]], [[126, 127, 258, 9, 4], [11, 34, 409, 185, 95, 107, 61, 410, 411, 7, 2]], [[4, 5, 25, 136, 14], [16, 11, 4, 18, 6, 2]], [[4, 38, 24, 8, 85, 55, 137, 74, 75], [412, 103, 413, 48, 2]], [[6, 7, 74, 75, 4, 9], [4, 162, 58, 24, 9, 414, 415, 6, 2]], [[4, 5, 25, 136, 138, 138, 10, 14], [416, 417, 418, 8, 39, 116, 117, 2]], [[131, 259, 9, 42, 4], [6, 71, 45, 99, 25, 7, 19, 419, 420, 8, 9, 8, 2]], [[56, 16, 81, 81, 4], [421, 189, 182, 186, 97, 2]], [[4, 5, 19, 39, 262, 8], [4, 8, 31, 189, 426, 2]], [[15, 42, 4, 263, 9, 118, 89], [36, 8, 430, 431, 7, 2]], [[30, 7, 5, 91, 60], [4, 76, 42, 10, 442, 9, 2]], [[21, 33, 4, 5, 270, 271, 54, 142, 10], [4, 21, 6, 10, 2]], [[24, 8, 116, 272, 143, 37], [57, 57, 28, 170, 444, 7, 445, 446, 447, 72, 23, 56, 2]], [[6, 7, 40, 4, 40, 40], [6, 59, 4, 59, 6, 59, 4, 59, 9, 59, 2]], [[107, 108, 29, 6, 21], [6, 19, 4, 69, 12, 448, 152, 46, 2]], [[273, 4, 10, 144, 274], [98, 67, 4, 118, 193, 103, 52, 37, 2]], [[90, 284, 285, 286, 29, 6, 287], [459, 6, 7, 460, 96, 36, 12, 2]], [[16, 92, 288, 9, 7, 289, 18, 9], [5, 5, 2]], [[31, 292, 43, 4, 30, 85, 92, 6], [461, 18, 6, 462, 2]], [[4, 53, 293, 89, 146, 25, 86, 6, 294], [5, 5, 2]], [[302, 93, 303, 41, 94, 57, 57, 41, 150], [110, 35, 196, 2]], [[305, 306, 13, 20, 122, 94, 307, 83, 308], [5, 5, 2]]], [[[151, 43, 152, 32, 25, 153, 63, 64, 15, 65, 33], [30, 61, 7, 2]], [[27, 4, 16, 68, 34, 26, 8], [23, 13, 32, 4, 219, 19, 16, 11, 17, 39, 11, 120, 18, 4, 65, 2]], [[4, 158, 5, 12], [4, 25, 25, 8, 220, 221, 123, 66, 222, 223, 7, 16, 11, 6, 15, 2]], [[4, 19, 36, 
10], [124, 88, 48, 124, 88, 48, 229, 230, 10, 231, 232, 8, 40, 89, 14, 233, 7, 234, 235, 10, 2]], [[11, 11], [26, 6, 33, 6, 26, 26, 10, 4, 19, 69, 12, 127, 128, 6, 41, 41, 2]], [[171, 4, 7, 49, 6, 6, 17, 7, 101, 4, 172], [36, 4, 24, 91, 51, 11, 134, 86, 2]], [[174, 175, 5, 27, 6, 7, 53, 47, 70, 176, 4, 47, 70, 17, 71], [4, 42, 10, 4, 143, 38, 9, 50, 10, 71, 94, 2]], [[4, 48, 72], [37, 121, 122, 261, 262, 53, 9, 43, 44, 21, 144, 65, 6, 14, 144, 65, 32, 2]], [[4, 54, 177, 37, 9], [9, 19, 82, 62, 145, 11, 145, 11, 53, 8, 20, 265, 96, 20, 266, 97, 2]], [[49, 6, 48, 72], [6, 8, 39, 45, 267, 21, 6, 73, 268, 20, 269, 41, 41, 41, 2]], [[4, 64, 38, 178, 179, 51], [98, 67, 10, 270, 271, 272, 273, 146, 274, 147, 146, 275, 147, 276, 2]], [[6, 21, 73, 4, 104, 25, 39, 17, 74, 75, 4, 9], [285, 100, 7, 149, 2]], [[6, 24, 8, 187, 4, 15, 26], [151, 294, 295, 21, 75, 10, 296, 297, 136, 137, 24, 61, 298, 10, 2]], [[4, 23], [315, 56, 6, 159, 316, 317, 318, 319, 6, 24, 320, 160, 34, 160, 34, 321, 55, 10, 57, 57, 57, 2]], [[197, 5, 19, 36], [4, 20, 87, 324, 325, 10, 6, 91, 28, 13, 326, 62, 9, 62, 4, 93, 28, 327, 6, 328, 161, 329, 2]], [[4, 19, 36, 7, 198, 199, 6, 18, 7, 114, 200, 4, 35], [5, 5, 2]], [[201, 29, 4, 202, 203, 57, 204, 205, 4, 17, 97, 13, 20, 13, 20, 115, 43, 9], [5, 5, 2]], [[37, 82, 208, 9, 209, 7, 73, 4, 29, 6, 54, 10, 9], [5, 5, 2]], [[55, 95, 66, 103, 118, 9, 211, 4, 117, 58], [5, 5, 2]], [[220, 4, 22, 8, 32, 221, 222, 223, 17, 5, 21, 4, 7, 121], [5, 5, 2]], [[21, 4, 7, 121], [9, 19, 46, 347, 9, 27, 4, 166, 17, 115, 166, 9, 19, 348, 10, 4, 46, 2]], [[59, 13, 20, 35, 9, 84], [4, 38, 21, 54, 169, 12, 354, 4, 170, 54, 10, 171, 171, 47, 2]], [[15, 123, 124, 99, 46, 17, 30, 85, 42, 6, 86, 9], [5, 5, 2]], [[6, 18, 7, 18, 65, 227, 35, 21, 6, 68, 34, 26], [175, 2]], [[125, 86, 230, 67, 231, 42, 126, 127, 232, 233, 234], [369, 370, 371, 2]], [[104, 25, 238, 101, 4, 38, 60, 239, 240, 10, 7, 71], [23, 17, 375, 376, 6, 127, 58, 2]], [[4, 82, 241, 242, 9], [51, 11, 73, 4, 78, 181, 377, 4, 73, 51, 11, 9, 75, 151, 2]], [[4, 5, 7, 5, 131], [4, 8, 9, 8, 105, 114, 388, 7, 161, 75, 72, 8, 105, 389, 50, 7, 390, 391, 58, 2]], [[4, 250, 105, 5, 12], [8, 17, 39, 126, 396, 397, 398, 399, 400, 13, 184, 109, 401, 7, 20, 402, 64, 150, 110, 35, 70, 80, 2]], [[6, 5, 139, 39, 61], [120, 12, 422, 148, 47, 29, 12, 53, 9, 92, 54, 169, 7, 125, 129, 94, 17, 83, 2]], [[46, 47, 10, 260, 4, 5, 261, 39, 61, 80, 5, 139, 39, 61], [423, 424, 4, 154, 77, 9, 190, 6, 7, 425, 2]], [[4, 265, 90, 9], [8, 32, 8, 32, 31, 108, 66, 433, 10, 31, 108, 66, 18, 16, 11, 10, 2]], [[275, 276, 4, 145, 10, 60, 7, 277, 278, 279, 280], [4, 8, 449, 450, 94, 29, 12, 53, 81, 9, 451, 4, 2]], [[4, 21, 30, 5, 281, 282, 54, 142, 7, 145, 283, 10], [5, 5, 2]], [[92, 4, 148, 149, 102, 115, 9, 140, 33, 295, 89, 146, 296], [5, 5, 2]], [[297, 141, 298, 299, 6, 300, 114, 16, 14, 301, 16, 52], [5, 5, 2]]], [[[4, 18, 70, 50, 6, 51, 173, 52], [68, 135, 8, 250, 51, 11, 7, 42, 251, 252, 253, 136, 137, 21, 21, 14, 6, 68, 135, 7, 138, 254, 139, 14, 255, 140, 141, 32, 256, 9, 92, 93, 142, 52, 257, 10, 2]], [[24, 8, 107, 108], [6, 102, 22, 4, 18, 22, 6, 13, 22, 103, 104, 22, 4, 18, 22, 6, 13, 22, 292, 22, 293, 22, 22, 22, 2]], [[22, 8, 77, 78], [40, 8, 14, 76, 28, 9, 305, 64, 68, 7, 55, 153, 306, 8, 154, 155, 155, 7, 156, 307, 52, 37, 84, 18, 106, 53, 9, 43, 44, 6, 72, 100, 8, 45, 15, 2]], [[4, 38, 10, 27, 112, 76, 5, 12], [45, 310, 89, 14, 14, 6, 311, 90, 77, 140, 7, 11, 24, 8, 6, 7, 40, 6, 19, 312, 40, 27, 313, 40, 18, 40, 43, 44, 10, 149, 2]], [[17, 5, 6, 87, 5, 50, 4, 9], 
[4, 158, 52, 9, 8, 355, 356, 28, 66, 357, 27, 358, 7, 174, 6, 13, 29, 12, 21, 359, 6, 19, 93, 28, 43, 44, 7, 174, 2]], [[266, 29, 6, 125, 112, 91, 60], [434, 6, 25, 159, 191, 24, 106, 39, 435, 192, 15, 6, 25, 24, 56, 6, 106, 12, 183, 191, 15, 6, 25, 8, 107, 25, 47, 436, 6, 33, 15, 437, 2]], [[13, 20], [16, 452, 6, 453, 194, 12, 143, 454, 4, 7, 194, 192, 19, 60, 10, 60, 10, 60, 10, 60, 15, 60, 455, 456, 457, 458, 2]], [[304, 93, 41, 150, 93, 120], [196, 463, 464, 112, 112, 110, 465, 48, 466, 35, 467, 80, 70, 132, 80, 70, 197, 197, 35, 112, 35, 131, 80, 35, 168, 35, 2]]]]
train bucket sizes: [87, 69, 36, 8]
total size: 200.0
length of data set: 4
train buckets scale: [0.435, 0.78, 0.96, 1.0]
  • Analysis
    (1) Four bucket structures are defined, so the question/answer data is split into four groups and each bucket stores the pairs that fit its sizes; [87, 69, 36, 8] means the four buckets hold 87, 69, 36 and 8 dialogues respectively;
    (2) A training pair is stored with its word IDs in the first bucket whose lengths it fits; entries with no content (empty lines in the corpus) show up as empty lists;
    (3) Cumulative proportion of dialogues per bucket: [0.435, 0.78, 0.96, 1.0]; this is what the training loop later samples buckets from (see the sketch below).
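
To make the role of train_buckets_scale concrete, here is a minimal standalone sketch (bucket sizes hard-coded from the run above) of the sampling rule that the training loop in section 3 uses to pick a bucket in proportion to how many dialogues it holds:

import numpy as np

train_bucket_sizes = [87, 69, 36, 8]
total_size = float(sum(train_bucket_sizes))
# Cumulative proportions, as printed above: [0.435, 0.78, 0.96, 1.0]
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / total_size
                       for i in range(len(train_bucket_sizes))]

# Draw a uniform random number and take the first bucket whose cumulative
# proportion exceeds it, so larger buckets are picked more often.
random_number_01 = np.random.random_sample()
bucket_id = min([i for i in range(len(train_buckets_scale))
                 if train_buckets_scale[i] > random_number_01])
print("sampled bucket id: {}".format(bucket_id))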

2 seq2seq Processing of Word Vectors

2.1 Getting Questions, Answers and Answer Weights

  • Demo
import tensorflow as tf  # written against the TensorFlow 0.12 / 1.x API
import seq2seq_model
import os
import random
import numpy as np
import math
from six.moves import xrange  # Python 2/3 compatibility for the loops below
 
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
 
train_encode_vec = './data/word2vec/train_question_encode.vec'
train_decode_vec = './data/word2vec/train_answer_decode.vec'
test_encode_vec = './data/word2vec/test_question_encode.vec'
test_decode_vec = './data/word2vec/test_answer_decode.vec'
 
# Vocabulary size (470 distinct word IDs in this corpus)
vocabulary_encode_size = 470
vocabulary_decode_size = 470
 
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
def read_data(question_path, answer_path, max_size=None):
	data_set = [[] for _ in buckets]
	with tf.gfile.GFile(question_path, mode="r") as question_file:
		with tf.gfile.GFile(answer_path, mode="r") as answer_file:
			question, answer = question_file.readline(), answer_file.readline()
			counter = 0
			while question and answer and (not max_size or counter < max_size):
				counter += 1
				question_ids = [int(x) for x in question.split()]
				answer_ids = [int(x) for x in answer.split()]
				answer_ids.append(EOS_ID)
				for bucket_id, (question_size, answer_size) in enumerate(buckets):
					if len(question_ids) < question_size and len(answer_ids) < answer_size:
						data_set[bucket_id].append([question_ids, answer_ids])
						break
				question, answer = question_file.readline(), answer_file.readline()
				print("question: {}, answer: {}".format(question, answer))
	return data_set
# Hyperparameters: layer_size and num_layers as in section 4; a batch_size of 64
# matches the batch arrays printed in the Result below.
layer_size = 256
num_layers = 3
batch_size = 64

# Note: the keyword names must match your own seq2seq_model.Seq2SeqModel signature
# (the class listed in section 2.3 names them source_vocab_size / target_vocab_size).
model = seq2seq_model.Seq2SeqModel(question_vocab_size=vocabulary_encode_size, answer_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97, forward_only=False)

train_set = read_data(train_encode_vec, train_decode_vec)
def get_batch(data, bucket_id):
    """Get a random batch of data from the specified bucket, prepare for step.

    To feed data in step(..) it must be a list of batch-major vectors, while
    data here contains single length-major cases. So the main logic of this
    function is to re-index data cases to be in the proper format for feeding.

    Args:
      data: per-bucket list of word-ID pairs, e.g. [[[4, 4], [5, 6, 8]], ...]
      bucket_id: index of the bucket to draw from (sampled from the bucket proportions)
    Returns:
      The triple (encoder_inputs, decoder_inputs, answer_weights) for
      the constructed batch that has the proper format to call step(...) later.
    """
    # Question and answer lengths for this bucket: buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
    encoder_size, decoder_size = buckets[bucket_id]
    # Containers for the sampled questions and answers
    encoder_inputs, decoder_inputs = [], []

    # Get a random batch of encoder and decoder inputs from data,
    # pad them if needed, reverse encoder inputs and add GO to decoder.
    for _ in xrange(batch_size):
      # Randomly pick one question/answer pair from this bucket
      encoder_input, decoder_input = random.choice(data[bucket_id])

      # Pad the question with PAD_ID and reverse it
      encoder_pad = [PAD_ID] * (encoder_size - len(encoder_input))
      encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))

      # Prepend GO_ID to the answer and pad it with PAD_ID
      decoder_pad_size = decoder_size - len(decoder_input) - 1
      decoder_inputs.append([GO_ID] + decoder_input +
                            [PAD_ID] * decoder_pad_size)

    # Batch-major questions, answers and weights
    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []

    # Batch-major questions
    for length_idx in xrange(encoder_size):
      batch_encoder_inputs.append(
          np.array([encoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(batch_size)], dtype=np.int32))

    # Batch-major answers
    for length_idx in xrange(decoder_size):
      batch_decoder_inputs.append(
          np.array([decoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(batch_size)], dtype=np.int32))

      # Answer (target) weights: a mask over the loss, 0 for padded positions
      batch_weight = np.ones(batch_size, dtype=np.float32)
      for batch_idx in xrange(batch_size):
        # The target for this position is the next decoder input; if it is
        # PAD_ID (or past the end of the answer), the weight is 0, otherwise 1.
        if length_idx < decoder_size - 1:
          answer = decoder_inputs[batch_idx][length_idx + 1]
        if length_idx == decoder_size - 1 or answer == PAD_ID:
          batch_weight[batch_idx] = 0.0
      batch_weights.append(batch_weight)
    print("encoder inputs: {}, decoder inputs: {}, answer weights: {}".format(
        batch_encoder_inputs, batch_decoder_inputs, batch_weights))
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights

encoder_inputs, decoder_inputs, answer_weights = get_batch(train_set, bucket_id=0)
  • Result
encoder inputs: [array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      dtype=int32), array([ 0,  0, 69,  0,  0, 69,  0,  0,  0,  0,  6,  0,  0,  0,  0, 78,  0,
        0,  0,  0, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0, 12,  0,  0,  0,  0, 12,  4,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0, 71, 69,  0,  0,  6,  0,  0,  0,  0,  0,  0], dtype=int32), array([  0,   0,  28,   0,   0,  28,   0,  32,   0,   0,  62,   0,   0,
         0,   0,  77,   0,   0,   8,   8,  11,   0,   9,   0,   0,   0,
         0,   0,   0,   0,  12,  10,   0, 106,   0,   0,   5,   0,   0,
         0,   5,   5, 216,   0,   0,  11,   0,   0,   0,   0,  78,   0,
         0,   7,  28,   0,   0,  62,   0,   0,   0,   0,   0,   0],
      dtype=int32), array([  0,   0,   4,   0,   0,   4,  90,  32,  20,   0,  62,   4, 267,
        20, 119,   8,  14,   0,  22,  10,  11, 246,   8,   0,   5, 119,
        90,   0,   0,  14,   5,  79, 102,   5,  51,   0,  45, 113,  37,
        90,  55,  45,  18,   0,   0,  11,   0,  84, 119,   8,  77,   0,
         0,   4,   4,   0,   0,  62,   0,   0,  20, 113,   0,  23],
      dtype=int32), array([309,  58,  27,  41,  41,  27,  31,  32,  13, 122,   4,  87,  66,
        13, 119,  22,  14,   0, 137,  31,  11, 245,  24, 217,   7, 119,
        31, 219, 212,  14,  76,  19,  13,   4,   5,   6,  44, 113, 143,
        31,   4,  44,  17,  41,   0,  11, 219,  84, 119,  22,  16, 226,
       213,  27,  27, 120, 219,   4, 244, 219,  13, 113, 309,   4],
      dtype=int32)]
decoder inputs: [array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32), array([ 69,  56,   4, 343, 168,   4, 195,   5, 177, 353,  98,   5, 438,
       249,   5,  36,  16,   5,  45,  30,  26, 387,   4, 338,  36,   5,
       195, 164,   5,  16,  95,  74, 258,   4,  72,   6, 216, 314,   5,
       195,   4, 203,   5, 168,   5,  26, 164, 443,   5, 361, 177, 352,
         5,   4,   4,  88, 164,  98,  16, 164, 249, 314,  69,   4],
      dtype=int32), array([ 12,   2, 121,   2,   2, 121, 195,   5,  78,   2,  67,   5, 193,
        14,   5,  29,  11,   5,  99,   8,  26,   6,  37,   2,   8,   5,
       195, 164,   5,  11, 107,  74, 259,   8,  84,  29, 217,   2,   5,
       195,   9, 204,   5,   2,   5,   6, 164, 190,   5, 362,  78,   2,
         5,  14, 121,   2, 164,  67, 181, 164,  14,   2,  12,  23],
      dtype=int32), array([ 10,   0, 122,   0,   0, 122,   2,   2, 104,   0,   8,   2,   2,
         2,   2,  12,   2,   2,  26,   7,   6,  17,  81,   0,  29,   2,
         2,   2,   2,   2,   8,  34,  17, 101,  95,  12,   2,   0,   2,
         2,   8, 205,   2,   0,   2,  33,   2,   2,   2, 363, 104,   0,
         2, 263, 122,   0,   2,   8,  13,   2,   2,   0,  10,  13],
      dtype=int32), array([ 55,   0,   2,   0,   0,   2,   0,   0,  77,   0,   4,   0,   0,
         0,   0,  13,   0,   0,  26,   2,  33, 183, 113,   0,  12,   0,
         0,   0,   0,   0,   6,   7,   8,   2,   2,   2,   0,   0,   0,
         0, 157, 206,   0,   0,   0,  15,   0,   0,   0, 364,  77,   0,
         0, 264,   2,   0,   0,   4, 385,   0,   0,   0,  55,  33],
      dtype=int32), array([469,   0,   0,   0,   0,   0,   0,   0,   9,   0, 111,   0,   0,
         0,   0, 360,   0,   0,  65,   0,   2,  38, 114,   0,   2,   0,
         0,   0,   0,   0,  63,   2,  31,   0,   0,   0,   0,   0,   0,
         0,   6,  31,   0,   0,   0,   2,   0,   0,   0, 176,   9,   0,
         0,  32,   0,   0,   0, 111, 165,   0,   0,   0, 469,   4],
      dtype=int32), array([  9,   0,   0,   0,   0,   0,   0,   0, 178,   0, 157,   0,   0,
         0,   0, 153,   0,   0,   2,   0,   0,   9, 198,   0,   0,   0,
         0,   0,   0,   0, 119,   0, 260,   0,   0,   0,   0,   0,   0,
         0,  71, 116,   0,   0,   0,   0,   0,   0,   0,   2, 178,   0,
         0,   2,   0,   0,   0, 157, 386,   0,   0,   0,   9,  13],
      dtype=int32), array([ 27,   0,   0,   0,   0,   0,   0,   0,   2,   0,   2,   0,   0,
         0,   0,   2,   0,   0,   0,   0,   0,  27,   4,   0,   0,   0,
         0,   0,   0,   0,   2,   0,   2,   0,   0,   0,   0,   0,   0,
         0,   8, 117,   0,   0,   0,   0,   0,   0,   0,   0,   2,   0,
         0,   0,   0,   0,   0,   2,   2,   0,   0,   0,  27, 180],
      dtype=int32), array([  2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0, 134,   2,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   2,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   2, 180],
      dtype=int32), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
      dtype=int32)]
answer weights: [array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), array([1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
       0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1.], dtype=float32), array([1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
       0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
       1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1.,
       0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
       0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1.,
       0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
       0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
       0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
  • Analysis
    (1) The question and answer vectors are rearranged to fit the bucket sizes: pairs that already match are kept as they are; shorter ones are padded, questions with PAD_ID (and then reversed), answers with a leading GO_ID and trailing PAD_ID;
    (2) After padding, each answer position is given a weight: positions whose target is PAD_ID (i.e. IDs that were only added as padding) get weight 0, real answer positions get weight 1 (this mask feeds the loss; attention itself is handled inside embedding_attention_seq2seq);
    (3) get_batch() turns the word-ID lists into batch-major arrays and returns the questions, answers and answer weights, i.e. encoder_inputs, decoder_inputs, answer_weights as shown in the result above. A standalone sketch of steps (1) and (2) follows.
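
A minimal standalone sketch of steps (1) and (2) for a single pair taken from the bucket-(5, 10) data above, using the ID constants defined in the demo (PAD_ID=0, GO_ID=1):

PAD_ID, GO_ID = 0, 1
encoder_size, decoder_size = 5, 10
encoder_input, decoder_input = [24, 8, 9], [4, 37, 81, 113, 114, 198, 4, 2]

# Question: pad with PAD_ID up to encoder_size, then reverse.
encoder_pad = [PAD_ID] * (encoder_size - len(encoder_input))
encoder = list(reversed(encoder_input + encoder_pad))

# Answer: prepend GO_ID, pad with PAD_ID up to decoder_size.
decoder_pad_size = decoder_size - len(decoder_input) - 1
decoder = [GO_ID] + decoder_input + [PAD_ID] * decoder_pad_size

# Weights: a position gets weight 0 when its target (the next decoder input)
# is PAD_ID or when it is the last position, and 1 otherwise.
weights = [0.0 if i == decoder_size - 1 or decoder[i + 1] == PAD_ID else 1.0
           for i in range(decoder_size)]
print(encoder)  # [0, 0, 9, 8, 24]
print(decoder)  # [1, 4, 37, 81, 113, 114, 198, 4, 2, 0]
print(weights)  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]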

2.2 Processing Questions, Answers and Answer Weights

  • Demo
def step(self, session, encoder_inputs, decoder_inputs, answer_weights,
           bucket_id, forward_only):
    """Run a step of the model feeding the given inputs.

    Args:
      session: TensorFlow session to use.
      encoder_inputs: list of question vectors.
      decoder_inputs: list of answer vectors.
      answer_weights: list of answer-weight vectors.
      bucket_id: which bucket of the model to use.
      forward_only: whether to run only the forward pass or also the backward pass.
    Returns:
      A triple of the gradient norm (None if the backward pass is not run),
      the average perplexity (loss), and the outputs.
    Raises:
      ValueError: if length of encoder_inputs, decoder_inputs, or
        answer_weights disagrees with bucket size for the specified bucket_id.
    """
    # Match the question/answer batch to the bucket sizes
    encoder_size, decoder_size = self.buckets[bucket_id]
    if len(encoder_inputs) != encoder_size:
      raise ValueError("Encoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(encoder_inputs), encoder_size))
    if len(decoder_inputs) != decoder_size:
      raise ValueError("Decoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(decoder_inputs), decoder_size))
    if len(answer_weights) != decoder_size:
      raise ValueError("Weights length must be equal to the one in bucket,"
                       " %d != %d." % (len(answer_weights), decoder_size))

    # Input feed: encoder inputs, decoder inputs, answer_weights, as provided.
    input_feed = {}
    for l in xrange(encoder_size):
      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
    for l in xrange(decoder_size):
      input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
      input_feed[self.answer_weights[l].name] = answer_weights[l]

    # Since our answers are decoder inputs shifted by one, we need one more.
    last_answer = self.decoder_inputs[decoder_size].name
    input_feed[last_answer] = np.zeros([self.batch_size], dtype=np.int32)

    # Output feed: depends on whether we do a backward step or not.
    if not forward_only:
      output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
                     self.gradient_norms[bucket_id],  # Gradient norm.
                     self.losses[bucket_id]]  # Loss for this batch.
    else:
      output_feed = [self.losses[bucket_id]]  # Loss for this batch.
      for l in xrange(decoder_size):  # Output logits.
        output_feed.append(self.outputs[bucket_id][l])

    outputs = session.run(output_feed, input_feed)
    if not forward_only:
      return outputs[1], outputs[2], None  # Gradient norm, loss, no outputs.
    else:
      return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.
  • Result

Backward pass (training step)

gradient norm: 0.14181189239025116, step loss: 0.024454593658447266

Forward pass (evaluation/inference step)
Partial output:

[array([[-8.61980557e-01,  7.83682689e-02, -6.27610302e+00,
         4.22952592e-01, -6.75087309e+00,  4.07743502e+00,
         6.32381868e+00,  7.75186586e+00, -9.86352444e+00,
        ...
        -9.14800549e+00,  8.39417934e+00, -6.37120008e+00,
        -3.18058801e+00, -4.23562908e+00, -3.24618030e+00,
         4.26430941e+00, -4.64457321e+00, -2.33847499e+00,
         1.18018208e+01, -3.61384606e+00]], dtype=float32), array......
  • Analysis
    (1) Based on the incoming question/answer vector lists, the matching bucket is selected and the lists are packed into the input feed dictionary input_feed = {};
    (2) The output feed output_feed depends on whether the backward pass is run: with back-propagation it holds the update op, the gradient norm and the loss; without it, only the loss plus the output logits;
    (3) The return value therefore comes in two forms: with back-propagation, the gradient norm and the loss (training, as in the backward-pass result above); forward only, the loss and the output vectors (used when loading the model for testing, as in the forward-pass result). See the usage sketch below.
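
A short usage sketch of the two modes (model, sess and the batch tensors are assumed to already exist, as in the training script of section 3):

# Training step (forward_only=False): runs the update op, returns the gradient norm and the loss.
gradient_norm, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                         answer_weights, bucket_id, False)

# Forward-only step (forward_only=True): returns the loss and the output logits,
# used for evaluation and inference.
_, eval_loss, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         answer_weights, bucket_id, True)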

2.3 The seq2seq Model

# Imports assumed by this class (TensorFlow 1.x API)
import copy
import numpy as np  # used by the step() method shown in section 2.2
import tensorflow as tf
from six.moves import xrange

class Seq2SeqModel(object):
  def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
               num_layers, max_gradient_norm, batch_size, learning_rate,
               learning_rate_decay_factor, use_lstm=False,
               num_samples=512, forward_only=False):
    self.source_vocab_size = source_vocab_size
    self.target_vocab_size = target_vocab_size
    self.buckets = buckets
    self.batch_size = batch_size
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
    self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    output_projection = None
    softmax_loss_function = None
    
    if num_samples > 0 and num_samples < self.target_vocab_size:
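      # This branch sets up sampled softmax; it only runs when num_samples is smaller
      # than the target vocabulary. With vocabulary size 470 and the default
      # num_samples=512 it is skipped, so a full softmax (no output projection) is used.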
      w = tf.get_variable("proj_w", [size, self.target_vocab_size])
      w_t = tf.transpose(w)
      b = tf.get_variable("proj_b", [self.target_vocab_size])
      output_projection = (w, b)

      def sampled_loss(labels,logits):
        labels = tf.reshape(labels, [-1, 1])
        return tf.nn.sampled_softmax_loss(w_t, b, labels,logits,num_samples,self.target_vocab_size)
      
      softmax_loss_function = sampled_loss
    single_cell = tf.contrib.rnn.GRUCell(size)
    if use_lstm:
      single_cell = tf.contrib.rnn.BasicLSTMCell(size)
    cell = single_cell
    if num_layers > 1:
      cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)

    def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
      tem_cell=copy.deepcopy(cell)
      return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
          encoder_inputs, decoder_inputs, tem_cell,
          num_encoder_symbols=source_vocab_size,
          num_decoder_symbols=target_vocab_size,
          embedding_size=size,
          output_projection=output_projection,
          feed_previous=do_decode)

   
    self.encoder_inputs = []
    self.decoder_inputs = []
    self.target_weights = []
    for i in xrange(buckets[-1][0]):  # Last bucket is the biggest one.
      self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="encoder{0}".format(i)))
    for i in xrange(buckets[-1][1] + 1):
      self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="decoder{0}".format(i)))
      self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                name="weight{0}".format(i)))

    targets = [self.decoder_inputs[i + 1]
               for i in xrange(len(self.decoder_inputs) - 1)]

    # Training outputs and losses.
    if forward_only:
      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
          softmax_loss_function=softmax_loss_function)
      
      if output_projection is not None:
        for b in xrange(len(buckets)):
          self.outputs[b] = [
              tf.matmul(output, output_projection[0]) + output_projection[1]
              for output in self.outputs[b]
          ]
    else:
      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets,
          lambda x, y: seq2seq_f(x, y, False),
          softmax_loss_function=softmax_loss_function)

    params = tf.trainable_variables()
    if not forward_only:
      self.gradient_norms = []
      self.updates = []
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
      for b in xrange(len(buckets)):
        gradients = tf.gradients(self.losses[b], params)
        clipped_gradients, norm = tf.clip_by_global_norm(gradients,
                                                         max_gradient_norm)
        self.gradient_norms.append(norm)
        self.updates.append(opt.apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step))

    self.saver = tf.train.Saver(tf.all_variables())
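
Note on the placeholder counts above: buckets[-1][1] + 1 decoder placeholders are created because the training targets are the decoder inputs shifted left by one position (see the targets list comprehension), and step() feeds one extra all-zero decoder input at the end. A toy illustration (IDs assumed for illustration only: GO=1, EOS=2, PAD=0):

# decoder_size = 5 for this toy bucket, so decoder_size + 1 positions are fed
decoder_inputs_example = [1, 24, 8, 9, 2, 0]
targets_example = decoder_inputs_example[1:]  # [24, 8, 9, 2, 0] -- what the loss compares the outputs against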

3 Training & Saving the Model

model = seq2seq_model.Seq2SeqModel(question_vocab_size=vocabulary_encode_size, answer_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm= 5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97, forward_only=False)

# Optional GPU configuration
# # config = tf.ConfigProto()
# # config.gpu_options.allocator_type = 'BFC'  # helps avoid out-of-memory errors
if __name__ == "__main__":

	with tf.Session() as sess:
	 
	# with tf.Session(config=config) as sess:
		# Check whether a previously trained model exists:
		# if so, read its step count from the checkpoint name and continue training;
		# otherwise start training from scratch
		ckpt = tf.train.get_checkpoint_state('./models')
		if ckpt != None:
			train_turn = ckpt.model_checkpoint_path.split('-')[1]
			print("model path: {}, train turns: {}".format(ckpt.model_checkpoint_path, train_turn))
			model.saver.restore(sess, ckpt.model_checkpoint_path)
			total_step = int(train_turn)
		else:
			sess.run(tf.global_variables_initializer())
			total_step = 0
	 
		train_set = read_data(train_encode_vec, train_decode_vec)
		test_set = read_data(test_encode_vec, test_decode_vec)
	 
		train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
		# print("train bucket sizes: {}".format(train_bucket_sizes))
		train_total_size = float(sum(train_bucket_sizes))
		train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
	 
		loss = 0.0
		# total_step = int(train_turn)
		previous_losses = []
		# Train indefinitely, saving a checkpoint at regular intervals
		while True:
			random_number_01 = np.random.random_sample()
			# get minimum i as bucket id when value > randmom value
			bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
	 
			encoder_inputs, decoder_inputs, answer_weights = model.get_batch(train_set, bucket_id)
			# print("encoder inputs: {}, decoder inputs: {}, answer weights: {}".format(encoder_inputs, decoder_inputs, answer_weights))
			gradient_norm, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, answer_weights, bucket_id, False)
			print("gradient norm: {}, step loss: {}".format(gradient_norm, step_loss))
			loss += step_loss / 500
			total_step += 1
	 
			print("total step: {}".format(total_step))
			if total_step % 500 == 0:
				print("global step: {}, learning rate: {}, loss: {}".format(model.global_step.eval(), model.learning_rate.eval(), loss))
	 
				# If the loss has not improved, decay the learning rate
				if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
					sess.run(model.learning_rate_decay_op)
				previous_losses.append(loss)
				# Save a checkpoint
				checkpoint_path = "./models/chatbot_seq2seq.ckpt"
				model.saver.save(sess, checkpoint_path, global_step=model.global_step)
				loss = 0.0
				# Evaluate the model on the test set
				for bucket_id in range(len(buckets)):
					if len(test_set[bucket_id]) == 0:
						continue
					encoder_inputs, decoder_inputs, answer_weights = model.get_batch(test_set, bucket_id)
					_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, answer_weights, bucket_id, True)
					eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
					print("bucket id: {}, eval ppx: {}".format(bucket_id, eval_ppx))

  • Result
total step: 1
total step: 2
total step: 3
total step: 4
total step: 5
...
global step: 500, learning rate: 0.5, loss: 2.574068747580052
bucket id: 0, eval ppx: 14176.588030763274
bucket id: 1, eval ppx: 3650.0026667220773
bucket id: 2, eval ppx: 4458.454110999805
bucket id: 3, eval ppx: 5290.083583183104

4 Loading the Model & Testing

  • Demo
import tensorflow as tf  # 0.12
import seq2seq_model
import os
import numpy as np
 
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
 
train_encode_vocabulary = './data/word2vec/train_question_encode_vocabulary'
train_decode_vocabulary = './data/word2vec/train_answer_decode_vocabulary'

# train_encode_vec = './data/word2vec/train_question_encode.vec'
# train_decode_vec = './data/word2vec/train_answer_decode.vec'
# test_encode_vec = './data/word2vec/test_question_encode.vec'
# test_decode_vec = './data/word2vec/test_answer_decode.vec'

 
def read_vocabulary(input_file):
	tmp_vocab = []
	with open(input_file, "r") as f:
		tmp_vocab.extend(f.readlines())
	tmp_vocab = [line.strip() for line in tmp_vocab]
	vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])
	return vocab, tmp_vocab
 
vocab_en, _, = read_vocabulary(train_encode_vocabulary)
_, vocab_de, = read_vocabulary(train_decode_vocabulary)
 
# Vocabulary size (470 distinct word IDs)
vocabulary_encode_size = 470
vocabulary_decode_size = 470
 
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256  # size of each layer
num_layers = 3   # number of layers
batch_size =  1
 
# model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
#                                    buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm= 5.0,
#                                    batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)

model = seq2seq_model.Seq2SeqModel(question_vocab_size=vocabulary_encode_size, answer_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm= 5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)
# Decode one sentence at a time at inference
model.batch_size = 1

with tf.Session() as sess:
	# Load the trained model
	ckpt = tf.train.get_checkpoint_state('./models')
	if ckpt != None:
		print(ckpt.model_checkpoint_path)
		model.saver.restore(sess, ckpt.model_checkpoint_path)
	else:
		print("没找到模型")
 
	while True:
		input_string = input('me > ')
		# Quit
		if input_string == 'quit':
			exit()
 
		input_string_vec = []
		for words in input_string.strip():
			input_string_vec.append(vocab_en.get(words, UNK_ID))
		bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])
		encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)
		_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
		print("output logits: {}".format(output_logits))
		outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
		if EOS_ID in outputs:
			outputs = outputs[:outputs.index(EOS_ID)]
 
		response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])
		print('AI > ' + response)
  • Result
    [screenshot of a sample conversation omitted]
Figure 4.1: dialogue test

5 Summary

(1) This chatbot uses the bucket structure: question/answer lengths are fixed per bucket, each pair is matched to the bucket whose sizes it fits, and data is stored and sampled per bucket;
(2) When this seq2seq model is built with TensorFlow, it does not create its graph inside an explicitly held, uniquely identified graph object, so when the model is wrapped into a backend service the default graph can turn out to be empty; a minimal workaround sketch is given below.
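
For point (2), a minimal workaround sketch (not part of the original code): build and restore the model inside an explicitly created tf.Graph, keep that graph and its session alive in the backend process, and always enter the graph before running the model. The constructor arguments follow the call used in section 4.

import tensorflow as tf
import seq2seq_model

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
graph = tf.Graph()  # a graph object the service holds on to explicitly
with graph.as_default():
	model = seq2seq_model.Seq2SeqModel(
		question_vocab_size=470, answer_vocab_size=470, buckets=buckets,
		size=256, num_layers=3, max_gradient_norm=5.0, batch_size=1,
		learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)
	sess = tf.Session(graph=graph)
	ckpt = tf.train.get_checkpoint_state('./models')
	model.saver.restore(sess, ckpt.model_checkpoint_path)

def answer_step(encoder_inputs, decoder_inputs, answer_weights, bucket_id):
	# Always run against the held graph; otherwise TensorFlow falls back to the
	# (empty) default graph of the calling thread and the restored variables are missing.
	with graph.as_default():
		return model.step(sess, encoder_inputs, decoder_inputs, answer_weights, bucket_id, True)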



