自然语言处理:分词(断词)与关键词提取方法(一)
自然语言处理:scrapy爬取关键词信息(二)
自然语言处理:问答语料生成词汇表,词转向量(三)
1 处理数据集
- Demo
import tensorflow as tf
# Corpus length-bucket configuration.
# (5, 10) means: questions up to length 5, answers up to length 10.
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
def read_data(question_path, answer_path, max_size=None):
    """Read aligned question/answer id files and sort each pair into a bucket.

    Args:
        question_path: file of space-separated question token ids, one per line.
        answer_path: file of space-separated answer token ids, one per line.
        max_size: optional cap on the number of QA pairs to read.

    Returns:
        data_set: one list per bucket; each entry is [question_ids, answer_ids].
    """
    # One (initially empty) list per bucket configuration.
    data_set = [[] for _ in buckets]
    with tf.gfile.GFile(question_path, mode="r") as question_file:
        with tf.gfile.GFile(answer_path, mode="r") as answer_file:
            # The two files are aligned line by line: line i of the question
            # file pairs with line i of the answer file.
            question, answer = question_file.readline(), answer_file.readline()
            counter = 0
            while question and answer and (not max_size or counter < max_size):
                counter += 1
                question_ids = [int(x) for x in question.split()]
                answer_ids = [int(x) for x in answer.split()]
                # Mark the end of the answer sequence.
                answer_ids.append(EOS_ID)
                # Place the pair in the first bucket large enough for it;
                # pairs longer than the largest bucket are dropped.
                for bucket_id, (question_size, answer_size) in enumerate(buckets):
                    if len(question_ids) < question_size and len(answer_ids) < answer_size:
                        data_set[bucket_id].append([question_ids, answer_ids])
                        break
                question, answer = question_file.readline(), answer_file.readline()
    return data_set
# Bucketed word-id data set.
train_set = read_data(train_encode_vec, train_decode_vec)
print("Train set: {}".format(train_set))
# Number of QA pairs per bucket (one question + one answer = one dialogue).
train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
print("train bucket sizes: {}".format(train_bucket_sizes))
# Total number of dialogues across all buckets.
total_size = float(sum(train_bucket_sizes))
print("total size: {}".format(total_size))
for i in train_set:
    print("data: {}".format(i))
# FIX: corrected "lenght" -> "length" in the log message.
print("length of data set: {}".format(len(train_set)))
# Cumulative share of dialogues per bucket; used later to sample a bucket.
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / total_size for i in range(len(train_bucket_sizes))]
print("train buckets scale: {}".format(train_buckets_scale))
- Result
[[[[7, 5], [36, 8, 29, 12, 2]], [[24, 8, 9], [4, 37, 81, 113, 114, 198, 4, 2]], [[154, 155], [30, 30, 4, 38, 82, 62, 2]], [[44, 45, 5, 12], [203, 204, 205, 206, 31, 116, 117, 2]], [[44, 45, 5, 12], [8, 214, 215, 7, 63, 119, 15, 2]], [[44, 45, 5, 12], [216, 217, 2]], [[27, 4, 28, 69], [4, 121, 122, 2]], [[162, 98, 163], [236, 237, 15, 2]], [[4], [2]], [[11, 11, 11], [243, 244, 4, 2]], [[13, 20], [249, 14, 2]], [[13, 102], [258, 259, 17, 8, 31, 260, 2]], [[5, 51], [72, 84, 95, 2]], [[27, 4, 7, 71], [4, 14, 263, 264, 32, 2]], [[28, 48, 72], [14, 21, 6, 15, 2]], [[182, 29, 6], [4, 37, 92, 54, 7, 2]], [[6, 7, 28, 183], [28, 281, 282, 2]], [[4, 5, 106], [4, 8, 101, 2]], [[4, 55, 5], [4, 9, 8, 157, 6, 71, 8, 2]], [[109, 110, 111], [20, 309, 74, 2]], [[19, 191, 14], [6, 9, 18, 4, 2]], [[76, 5, 12], [95, 107, 8, 6, 63, 119, 2]], [[14, 14], [16, 11, 2]], [[11, 11, 11], [26, 6, 33, 15, 2]], [[194, 195, 13, 23], [5, 5, 2]], [[79, 196], [5, 5, 2]], [[], [5, 5, 2]], [[113, 113], [314, 2]], [[40, 4, 23], [6, 20, 87, 2]], [[11, 11], [26, 6, 33, 2]], [[58], [56, 2]], [[6], [6, 29, 12, 2]], [[117, 210, 58], [5, 5, 2]], [[212], [5, 5, 2]], [[119, 119], [5, 5, 2]], [[], [5, 5, 2]], [[213], [5, 5, 2]], [[], [5, 5, 2]], [[], [5, 5, 2]], [[], [5, 5, 2]], [[214, 215], [20, 9, 10, 17, 335, 336, 337, 2]], [[17, 18, 216, 4], [5, 5, 2]], [[217], [338, 2]], [[218], [339, 340, 341, 342, 2]], [[219], [164, 164, 2]], [[41], [343, 2]], [[120], [88, 2]], [[224], [167, 167, 2]], [[41], [168, 2]], [[226], [352, 2]], [[122], [353, 2]], [[], [5, 5, 2]], [[32, 32, 32], [5, 5, 2]], [[22, 8, 77, 78], [36, 29, 12, 13, 360, 153, 2]], [[22, 8], [361, 362, 363, 364, 176, 2]], [[16, 77, 78], [177, 78, 104, 77, 9, 178, 2]], [[19, 79, 10], [74, 74, 34, 7, 2]], [[87, 4], [5, 5, 2]], [[4, 23], [4, 23, 13, 33, 4, 13, 180, 180, 2]], [[4, 129, 129], [378, 182, 379, 2]], [[244], [16, 181, 13, 385, 165, 386, 2]], [[245, 246], [387, 6, 17, 183, 38, 9, 27, 134, 2]], [[31], [404, 141, 2]], [[11, 11, 11, 11], [26, 26, 6, 
33, 2]], [[4, 62, 62, 6], [98, 67, 8, 4, 111, 157, 2]], [[55, 56, 16], [13, 2]], [[16], [23, 2]], [[56], [23, 23, 2]], [[4, 23], [2]], [[4, 140, 141, 8], [4, 123, 427, 428, 429, 2]], [[62, 62], [111, 111, 2]], [[4, 59, 264, 9], [17, 432, 2]], [[66, 267], [438, 193, 2]], [[268, 98, 269], [439, 440, 441, 2]], [[84, 84], [443, 190, 2]], [[137, 22, 8], [45, 99, 26, 26, 65, 2]], [[143, 37], [5, 5, 2]], [[13, 20], [177, 78, 104, 77, 9, 178, 2]], [[31, 90], [195, 195, 2]], [[15, 290, 291], [9, 4, 8, 100, 55, 7, 2]], [[147, 31, 148, 149], [69, 12, 10, 2]], [[24, 8, 144], [9, 43, 44, 2]], [[147], [468, 109, 14, 21, 6, 2]], [[31, 10, 8], [30, 8, 7, 2]], [[91, 4, 23], [4, 23, 33, 2]], [[309], [69, 12, 10, 55, 469, 9, 27, 2]], [[35, 57, 83, 94], [27, 46, 27, 46, 2]]], [[[4, 95, 66, 12, 8], [199, 115, 9, 8, 4, 8, 200, 201, 202, 2]], [[30, 8, 156], [207, 208, 209, 118, 38, 210, 9, 17, 83, 7, 211, 15, 2]], [[44, 45, 5, 12], [8, 4, 212, 213, 84, 18, 7, 16, 11, 7, 63, 85, 15, 2]], [[13, 96, 157, 4, 16, 67, 34, 26, 8], [63, 218, 8, 64, 7, 2]], [[35, 4, 19, 36, 10], [4, 86, 224, 6, 20, 87, 7, 2]], [[4, 28, 97, 6, 25, 159], [6, 225, 67, 226, 227, 228, 2]], [[4, 5, 160, 161, 14], [30, 30, 2]], [[99, 46, 5, 12, 10, 47, 164], [8, 4, 20, 238, 49, 239, 240, 7, 68, 241, 2]], [[4, 100, 7, 100], [6, 50, 4, 24, 50, 31, 83, 50, 125, 4, 126, 242, 2]], [[46, 165, 166, 167, 168], [245, 246, 247, 129, 248, 2]], [[28, 169, 170, 4, 28, 48, 8], [14, 90, 130, 131, 48, 132, 70, 90, 21, 133, 133, 2]], [[4, 10, 180, 181, 37, 5, 12], [277, 148, 99, 7, 278, 279, 280, 2]], [[4, 103, 27, 4, 184, 6], [4, 23, 13, 283, 284, 2]], [[15, 185, 186], [9, 19, 81, 113, 4, 286, 6, 54, 74, 96, 2]], [[4, 5, 105, 14, 51], [85, 101, 287, 7, 288, 289, 290, 150, 291, 2]], [[4, 188, 189, 53, 9], [299, 300, 7, 6, 301, 105, 152, 10, 2]], [[4, 38, 10, 76, 5, 12], [302, 303, 304, 2]], [[109, 110, 111, 5, 69, 190], [17, 39, 37, 308, 7, 64, 11, 2]], [[192, 63, 7, 193, 8], [5, 5, 2]], [[4, 16, 68, 34, 26, 56], [6, 158, 52, 32, 2]], 
[[15, 17, 40, 4, 19, 36], [6, 91, 322, 323, 2]], [[80, 29, 4], [11, 34, 162, 58, 24, 330, 6, 331, 332, 11, 34, 2]], [[11, 11], [51, 45, 73, 10, 163, 333, 334, 16, 11, 2]], [[206, 81, 13, 20, 9, 116, 207, 10], [5, 5, 2]], [[83], [16, 11, 13, 29, 12, 165, 344, 31, 108, 14, 345, 2]], [[82, 7, 31, 43, 9], [4, 76, 7, 8, 41, 346, 2]], [[225], [156, 349, 350, 138, 76, 351, 139, 38, 23, 13, 34, 2]], [[59, 7, 50, 4], [49, 49, 4, 42, 10, 172, 173, 4, 47, 2]], [[4, 53, 21, 228, 22, 8], [365, 366, 367, 2]], [[6, 49, 18, 229, 26], [5, 5, 2]], [[6, 49, 18, 67, 34, 26], [142, 7, 27, 27, 176, 368, 86, 2]], [[87, 59, 22, 8, 235, 236, 10, 128, 128], [5, 5, 2]], [[6, 237, 58, 15, 15], [372, 179, 373, 179, 374, 2]], [[243, 33, 6, 64, 130, 88, 8], [61, 7, 2]], [[4, 10, 130, 88, 5, 12], [380, 381, 382, 383, 78, 384, 2]], [[15, 42, 27, 4, 247], [30, 2]], [[4, 132, 61, 24, 8, 30, 8, 248], [6, 28, 25, 71, 392, 89, 15, 2]], [[4, 10, 132, 28, 69], [14, 36, 184, 109, 393, 75, 7, 394, 395, 2]], [[13, 96, 4, 5, 106, 80, 5, 14], [4, 8, 185, 101, 97, 2]], [[133, 14, 52, 5, 22, 8], [79, 79, 2]], [[134, 134, 249, 65, 33], [79, 79, 7, 186, 82, 85, 2]], [[251, 88, 7, 252, 253], [16, 11, 2]], [[135, 4], [58, 187, 4, 46, 187, 175, 10, 24, 56, 6, 130, 2]], [[254, 73, 9, 135, 50, 4], [49, 49, 4, 42, 10, 172, 173, 4, 47, 2]], [[7, 123, 124], [102, 102, 6, 7, 55, 47, 6, 43, 44, 128, 403, 7, 2]], [[4, 23, 4, 255, 79, 63, 5, 4, 23], [4, 8, 405, 406, 20, 407, 2]], [[4, 5, 14, 52], [9, 163, 18, 7, 6, 8, 4, 25, 408, 22, 2]], [[6, 18, 256, 4, 257, 133], [16, 11, 4, 42, 10, 188, 188, 2]], [[126, 127, 258, 9, 4], [11, 34, 409, 185, 95, 107, 61, 410, 411, 7, 2]], [[4, 5, 25, 136, 14], [16, 11, 4, 18, 6, 2]], [[4, 38, 24, 8, 85, 55, 137, 74, 75], [412, 103, 413, 48, 2]], [[6, 7, 74, 75, 4, 9], [4, 162, 58, 24, 9, 414, 415, 6, 2]], [[4, 5, 25, 136, 138, 138, 10, 14], [416, 417, 418, 8, 39, 116, 117, 2]], [[131, 259, 9, 42, 4], [6, 71, 45, 99, 25, 7, 19, 419, 420, 8, 9, 8, 2]], [[56, 16, 81, 81, 4], [421, 189, 182, 
186, 97, 2]], [[4, 5, 19, 39, 262, 8], [4, 8, 31, 189, 426, 2]], [[15, 42, 4, 263, 9, 118, 89], [36, 8, 430, 431, 7, 2]], [[30, 7, 5, 91, 60], [4, 76, 42, 10, 442, 9, 2]], [[21, 33, 4, 5, 270, 271, 54, 142, 10], [4, 21, 6, 10, 2]], [[24, 8, 116, 272, 143, 37], [57, 57, 28, 170, 444, 7, 445, 446, 447, 72, 23, 56, 2]], [[6, 7, 40, 4, 40, 40], [6, 59, 4, 59, 6, 59, 4, 59, 9, 59, 2]], [[107, 108, 29, 6, 21], [6, 19, 4, 69, 12, 448, 152, 46, 2]], [[273, 4, 10, 144, 274], [98, 67, 4, 118, 193, 103, 52, 37, 2]], [[90, 284, 285, 286, 29, 6, 287], [459, 6, 7, 460, 96, 36, 12, 2]], [[16, 92, 288, 9, 7, 289, 18, 9], [5, 5, 2]], [[31, 292, 43, 4, 30, 85, 92, 6], [461, 18, 6, 462, 2]], [[4, 53, 293, 89, 146, 25, 86, 6, 294], [5, 5, 2]], [[302, 93, 303, 41, 94, 57, 57, 41, 150], [110, 35, 196, 2]], [[305, 306, 13, 20, 122, 94, 307, 83, 308], [5, 5, 2]]], [[[151, 43, 152, 32, 25, 153, 63, 64, 15, 65, 33], [30, 61, 7, 2]], [[27, 4, 16, 68, 34, 26, 8], [23, 13, 32, 4, 219, 19, 16, 11, 17, 39, 11, 120, 18, 4, 65, 2]], [[4, 158, 5, 12], [4, 25, 25, 8, 220, 221, 123, 66, 222, 223, 7, 16, 11, 6, 15, 2]], [[4, 19, 36, 10], [124, 88, 48, 124, 88, 48, 229, 230, 10, 231, 232, 8, 40, 89, 14, 233, 7, 234, 235, 10, 2]], [[11, 11], [26, 6, 33, 6, 26, 26, 10, 4, 19, 69, 12, 127, 128, 6, 41, 41, 2]], [[171, 4, 7, 49, 6, 6, 17, 7, 101, 4, 172], [36, 4, 24, 91, 51, 11, 134, 86, 2]], [[174, 175, 5, 27, 6, 7, 53, 47, 70, 176, 4, 47, 70, 17, 71], [4, 42, 10, 4, 143, 38, 9, 50, 10, 71, 94, 2]], [[4, 48, 72], [37, 121, 122, 261, 262, 53, 9, 43, 44, 21, 144, 65, 6, 14, 144, 65, 32, 2]], [[4, 54, 177, 37, 9], [9, 19, 82, 62, 145, 11, 145, 11, 53, 8, 20, 265, 96, 20, 266, 97, 2]], [[49, 6, 48, 72], [6, 8, 39, 45, 267, 21, 6, 73, 268, 20, 269, 41, 41, 41, 2]], [[4, 64, 38, 178, 179, 51], [98, 67, 10, 270, 271, 272, 273, 146, 274, 147, 146, 275, 147, 276, 2]], [[6, 21, 73, 4, 104, 25, 39, 17, 74, 75, 4, 9], [285, 100, 7, 149, 2]], [[6, 24, 8, 187, 4, 15, 26], [151, 294, 295, 21, 75, 10, 296, 297, 136, 137, 
24, 61, 298, 10, 2]], [[4, 23], [315, 56, 6, 159, 316, 317, 318, 319, 6, 24, 320, 160, 34, 160, 34, 321, 55, 10, 57, 57, 57, 2]], [[197, 5, 19, 36], [4, 20, 87, 324, 325, 10, 6, 91, 28, 13, 326, 62, 9, 62, 4, 93, 28, 327, 6, 328, 161, 329, 2]], [[4, 19, 36, 7, 198, 199, 6, 18, 7, 114, 200, 4, 35], [5, 5, 2]], [[201, 29, 4, 202, 203, 57, 204, 205, 4, 17, 97, 13, 20, 13, 20, 115, 43, 9], [5, 5, 2]], [[37, 82, 208, 9, 209, 7, 73, 4, 29, 6, 54, 10, 9], [5, 5, 2]], [[55, 95, 66, 103, 118, 9, 211, 4, 117, 58], [5, 5, 2]], [[220, 4, 22, 8, 32, 221, 222, 223, 17, 5, 21, 4, 7, 121], [5, 5, 2]], [[21, 4, 7, 121], [9, 19, 46, 347, 9, 27, 4, 166, 17, 115, 166, 9, 19, 348, 10, 4, 46, 2]], [[59, 13, 20, 35, 9, 84], [4, 38, 21, 54, 169, 12, 354, 4, 170, 54, 10, 171, 171, 47, 2]], [[15, 123, 124, 99, 46, 17, 30, 85, 42, 6, 86, 9], [5, 5, 2]], [[6, 18, 7, 18, 65, 227, 35, 21, 6, 68, 34, 26], [175, 2]], [[125, 86, 230, 67, 231, 42, 126, 127, 232, 233, 234], [369, 370, 371, 2]], [[104, 25, 238, 101, 4, 38, 60, 239, 240, 10, 7, 71], [23, 17, 375, 376, 6, 127, 58, 2]], [[4, 82, 241, 242, 9], [51, 11, 73, 4, 78, 181, 377, 4, 73, 51, 11, 9, 75, 151, 2]], [[4, 5, 7, 5, 131], [4, 8, 9, 8, 105, 114, 388, 7, 161, 75, 72, 8, 105, 389, 50, 7, 390, 391, 58, 2]], [[4, 250, 105, 5, 12], [8, 17, 39, 126, 396, 397, 398, 399, 400, 13, 184, 109, 401, 7, 20, 402, 64, 150, 110, 35, 70, 80, 2]], [[6, 5, 139, 39, 61], [120, 12, 422, 148, 47, 29, 12, 53, 9, 92, 54, 169, 7, 125, 129, 94, 17, 83, 2]], [[46, 47, 10, 260, 4, 5, 261, 39, 61, 80, 5, 139, 39, 61], [423, 424, 4, 154, 77, 9, 190, 6, 7, 425, 2]], [[4, 265, 90, 9], [8, 32, 8, 32, 31, 108, 66, 433, 10, 31, 108, 66, 18, 16, 11, 10, 2]], [[275, 276, 4, 145, 10, 60, 7, 277, 278, 279, 280], [4, 8, 449, 450, 94, 29, 12, 53, 81, 9, 451, 4, 2]], [[4, 21, 30, 5, 281, 282, 54, 142, 7, 145, 283, 10], [5, 5, 2]], [[92, 4, 148, 149, 102, 115, 9, 140, 33, 295, 89, 146, 296], [5, 5, 2]], [[297, 141, 298, 299, 6, 300, 114, 16, 14, 301, 16, 52], [5, 5, 2]]], [[[4, 
18, 70, 50, 6, 51, 173, 52], [68, 135, 8, 250, 51, 11, 7, 42, 251, 252, 253, 136, 137, 21, 21, 14, 6, 68, 135, 7, 138, 254, 139, 14, 255, 140, 141, 32, 256, 9, 92, 93, 142, 52, 257, 10, 2]], [[24, 8, 107, 108], [6, 102, 22, 4, 18, 22, 6, 13, 22, 103, 104, 22, 4, 18, 22, 6, 13, 22, 292, 22, 293, 22, 22, 22, 2]], [[22, 8, 77, 78], [40, 8, 14, 76, 28, 9, 305, 64, 68, 7, 55, 153, 306, 8, 154, 155, 155, 7, 156, 307, 52, 37, 84, 18, 106, 53, 9, 43, 44, 6, 72, 100, 8, 45, 15, 2]], [[4, 38, 10, 27, 112, 76, 5, 12], [45, 310, 89, 14, 14, 6, 311, 90, 77, 140, 7, 11, 24, 8, 6, 7, 40, 6, 19, 312, 40, 27, 313, 40, 18, 40, 43, 44, 10, 149, 2]], [[17, 5, 6, 87, 5, 50, 4, 9], [4, 158, 52, 9, 8, 355, 356, 28, 66, 357, 27, 358, 7, 174, 6, 13, 29, 12, 21, 359, 6, 19, 93, 28, 43, 44, 7, 174, 2]], [[266, 29, 6, 125, 112, 91, 60], [434, 6, 25, 159, 191, 24, 106, 39, 435, 192, 15, 6, 25, 24, 56, 6, 106, 12, 183, 191, 15, 6, 25, 8, 107, 25, 47, 436, 6, 33, 15, 437, 2]], [[13, 20], [16, 452, 6, 453, 194, 12, 143, 454, 4, 7, 194, 192, 19, 60, 10, 60, 10, 60, 10, 60, 15, 60, 455, 456, 457, 458, 2]], [[304, 93, 41, 150, 93, 120], [196, 463, 464, 112, 112, 110, 465, 48, 466, 35, 467, 80, 70, 132, 80, 70, 197, 197, 35, 112, 35, 131, 80, 35, 168, 35, 2]]]]
train bucket sizes: [87, 69, 36, 8]
total size: 200.0
length of data set: 4
train buckets scale: [0.435, 0.78, 0.96, 1.0]
- Analysis
(1) 设定4个桶结构,即将问答分成4个部分,每个桶中存放对应的问答数据集;[87, 69, 36, 8]表示四个桶中分别有87组对话,69组对话,36组对话,8组对话;
(2) 训练词数据集符合桶长度则输入对应值,不符合桶长度,则为空;
(3) 对话数量占比:[0.435, 0.78, 0.96, 1.0];
1 词向量处理seq2seq
1.0 获取问答及答案权重
- Demo
import copy
import math
import os
import random

import numpy as np
import tensorflow as tf  # 0.12

import seq2seq_model
# Special token ids shared by encoder and decoder vocabularies.
PAD_ID = 0  # padding
GO_ID = 1   # decoder start-of-sequence
EOS_ID = 2  # end-of-sequence
UNK_ID = 3  # out-of-vocabulary token
# Word-id (vectorized) corpus files produced earlier in this series.
train_encode_vec = './data/word2vec/train_question_encode.vec'
train_decode_vec = './data/word2vec/train_answer_decode.vec'
test_encode_vec = './data/word2vec/test_question_encode.vec'
test_decode_vec = './data/word2vec/test_answer_decode.vec'
# Vocabulary sizes: 470 entries here (the original note said 5000, which
# does not match the value actually used).
vocabulary_encode_size = 470
vocabulary_decode_size = 470
# Length buckets: (max question length, max answer length).
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
def read_data(question_path, answer_path, max_size=None):
    """Read aligned question/answer id files and sort each pair into a bucket.

    Identical to the earlier read_data except for the debug print.

    Args:
        question_path: file of space-separated question token ids, one per line.
        answer_path: file of space-separated answer token ids, one per line.
        max_size: optional cap on the number of QA pairs to read.

    Returns:
        data_set: one list per bucket; each entry is [question_ids, answer_ids].
    """
    data_set = [[] for _ in buckets]
    with tf.gfile.GFile(question_path, mode="r") as question_file:
        with tf.gfile.GFile(answer_path, mode="r") as answer_file:
            question, answer = question_file.readline(), answer_file.readline()
            counter = 0
            while question and answer and (not max_size or counter < max_size):
                counter += 1
                question_ids = [int(x) for x in question.split()]
                answer_ids = [int(x) for x in answer.split()]
                # Mark the end of the answer sequence.
                answer_ids.append(EOS_ID)
                # First bucket large enough wins; oversized pairs are dropped.
                for bucket_id, (question_size, answer_size) in enumerate(buckets):
                    if len(question_ids) < question_size and len(answer_ids) < answer_size:
                        data_set[bucket_id].append([question_ids, answer_ids])
                        break
                question, answer = question_file.readline(), answer_file.readline()
                # NOTE(review): indentation was lost in the source; this debug
                # print appears to sit inside the loop (it logs the next pair
                # just read) — confirm against the original script.
                print("question: {}, answer: {}".format(question, answer))
    return data_set
# FIX: Seq2SeqModel.__init__ (defined below) declares source_vocab_size /
# target_vocab_size; the original call passed question_vocab_size /
# answer_vocab_size, which raises TypeError.
model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97, forward_only=False)
# NOTE(review): layer_size, num_layers and batch_size are only defined later
# in the article (256 / 3 / 1) — they must be in scope before this call.
train_set = read_data(train_encode_vec, train_decode_vec)
def get_batch(data, bucket_id, batch_size=64,
              buckets_spec=((5, 10), (10, 15), (20, 25), (40, 50))):
    """Get a random batch of data from the specified bucket, prepared for step().

    Data arrives as length-major QA pairs; step() needs batch-major vectors,
    so the main job here is padding and re-indexing.

    FIXES vs. the original snippet: the free function referenced an undefined
    `self` and `word_to_vec` module, used `random` without importing it,
    printed an undefined `answer_weights` name, never returned its results,
    and (with indentation lost) the weight computation must live inside the
    per-time-step decoder loop so that one weight vector is produced per
    decoder position.

    Args:
        data: list of buckets; each bucket is a list of [question_ids, answer_ids].
        bucket_id: index of the bucket to draw from.
        batch_size: number of QA pairs per batch (new, backward-compatible
            parameter; the original read it from the model object).
        buckets_spec: (encoder_size, decoder_size) per bucket; default matches
            the module-level bucket configuration.

    Returns:
        (batch_encoder_inputs, batch_decoder_inputs, batch_weights): time-major
        lists of numpy arrays, one array of batch_size values per time step.
    """
    import random  # local import: the snippet's header does not import random

    # PAD_ID / GO_ID as defined at module level (0 / 1).
    pad_id, go_id = 0, 1
    encoder_size, decoder_size = buckets_spec[bucket_id]
    encoder_inputs, decoder_inputs = [], []
    # Draw batch_size random QA pairs; pad questions (then reverse them) and
    # prepend GO / pad the answers.
    for _ in range(batch_size):
        encoder_input, decoder_input = random.choice(data[bucket_id])
        encoder_pad = [pad_id] * (encoder_size - len(encoder_input))
        encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
        decoder_pad_size = decoder_size - len(decoder_input) - 1
        decoder_inputs.append([go_id] + decoder_input +
                              [pad_id] * decoder_pad_size)
    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
    # Re-index to batch-major: one int32 array of batch_size values per step.
    for length_idx in range(encoder_size):
        batch_encoder_inputs.append(
            np.array([encoder_inputs[batch_idx][length_idx]
                      for batch_idx in range(batch_size)], dtype=np.int32))
    for length_idx in range(decoder_size):
        batch_decoder_inputs.append(
            np.array([decoder_inputs[batch_idx][length_idx]
                      for batch_idx in range(batch_size)], dtype=np.int32))
        # Weight is 0 where the target (decoder input shifted by one) is
        # padding or past the end of the sequence, 1 elsewhere.
        batch_weight = np.ones(batch_size, dtype=np.float32)
        for batch_idx in range(batch_size):
            if length_idx < decoder_size - 1:
                answer = decoder_inputs[batch_idx][length_idx + 1]
            if length_idx == decoder_size - 1 or answer == pad_id:
                batch_weight[batch_idx] = 0.0
        batch_weights.append(batch_weight)
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights
get_batch(data_set, bucket_id)
- Result
encoder inputs: [array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
dtype=int32), array([ 0, 0, 69, 0, 0, 69, 0, 0, 0, 0, 6, 0, 0, 0, 0, 78, 0,
0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 12, 0, 0, 0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 71, 69, 0, 0, 6, 0, 0, 0, 0, 0, 0], dtype=int32), array([ 0, 0, 28, 0, 0, 28, 0, 32, 0, 0, 62, 0, 0,
0, 0, 77, 0, 0, 8, 8, 11, 0, 9, 0, 0, 0,
0, 0, 0, 0, 12, 10, 0, 106, 0, 0, 5, 0, 0,
0, 5, 5, 216, 0, 0, 11, 0, 0, 0, 0, 78, 0,
0, 7, 28, 0, 0, 62, 0, 0, 0, 0, 0, 0],
dtype=int32), array([ 0, 0, 4, 0, 0, 4, 90, 32, 20, 0, 62, 4, 267,
20, 119, 8, 14, 0, 22, 10, 11, 246, 8, 0, 5, 119,
90, 0, 0, 14, 5, 79, 102, 5, 51, 0, 45, 113, 37,
90, 55, 45, 18, 0, 0, 11, 0, 84, 119, 8, 77, 0,
0, 4, 4, 0, 0, 62, 0, 0, 20, 113, 0, 23],
dtype=int32), array([309, 58, 27, 41, 41, 27, 31, 32, 13, 122, 4, 87, 66,
13, 119, 22, 14, 0, 137, 31, 11, 245, 24, 217, 7, 119,
31, 219, 212, 14, 76, 19, 13, 4, 5, 6, 44, 113, 143,
31, 4, 44, 17, 41, 0, 11, 219, 84, 119, 22, 16, 226,
213, 27, 27, 120, 219, 4, 244, 219, 13, 113, 309, 4],
dtype=int32)]
decoder inputs: [array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
dtype=int32), array([ 69, 56, 4, 343, 168, 4, 195, 5, 177, 353, 98, 5, 438,
249, 5, 36, 16, 5, 45, 30, 26, 387, 4, 338, 36, 5,
195, 164, 5, 16, 95, 74, 258, 4, 72, 6, 216, 314, 5,
195, 4, 203, 5, 168, 5, 26, 164, 443, 5, 361, 177, 352,
5, 4, 4, 88, 164, 98, 16, 164, 249, 314, 69, 4],
dtype=int32), array([ 12, 2, 121, 2, 2, 121, 195, 5, 78, 2, 67, 5, 193,
14, 5, 29, 11, 5, 99, 8, 26, 6, 37, 2, 8, 5,
195, 164, 5, 11, 107, 74, 259, 8, 84, 29, 217, 2, 5,
195, 9, 204, 5, 2, 5, 6, 164, 190, 5, 362, 78, 2,
5, 14, 121, 2, 164, 67, 181, 164, 14, 2, 12, 23],
dtype=int32), array([ 10, 0, 122, 0, 0, 122, 2, 2, 104, 0, 8, 2, 2,
2, 2, 12, 2, 2, 26, 7, 6, 17, 81, 0, 29, 2,
2, 2, 2, 2, 8, 34, 17, 101, 95, 12, 2, 0, 2,
2, 8, 205, 2, 0, 2, 33, 2, 2, 2, 363, 104, 0,
2, 263, 122, 0, 2, 8, 13, 2, 2, 0, 10, 13],
dtype=int32), array([ 55, 0, 2, 0, 0, 2, 0, 0, 77, 0, 4, 0, 0,
0, 0, 13, 0, 0, 26, 2, 33, 183, 113, 0, 12, 0,
0, 0, 0, 0, 6, 7, 8, 2, 2, 2, 0, 0, 0,
0, 157, 206, 0, 0, 0, 15, 0, 0, 0, 364, 77, 0,
0, 264, 2, 0, 0, 4, 385, 0, 0, 0, 55, 33],
dtype=int32), array([469, 0, 0, 0, 0, 0, 0, 0, 9, 0, 111, 0, 0,
0, 0, 360, 0, 0, 65, 0, 2, 38, 114, 0, 2, 0,
0, 0, 0, 0, 63, 2, 31, 0, 0, 0, 0, 0, 0,
0, 6, 31, 0, 0, 0, 2, 0, 0, 0, 176, 9, 0,
0, 32, 0, 0, 0, 111, 165, 0, 0, 0, 469, 4],
dtype=int32), array([ 9, 0, 0, 0, 0, 0, 0, 0, 178, 0, 157, 0, 0,
0, 0, 153, 0, 0, 2, 0, 0, 9, 198, 0, 0, 0,
0, 0, 0, 0, 119, 0, 260, 0, 0, 0, 0, 0, 0,
0, 71, 116, 0, 0, 0, 0, 0, 0, 0, 2, 178, 0,
0, 2, 0, 0, 0, 157, 386, 0, 0, 0, 9, 13],
dtype=int32), array([ 27, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0,
0, 0, 2, 0, 0, 0, 0, 0, 27, 4, 0, 0, 0,
0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0,
0, 8, 117, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,
0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 27, 180],
dtype=int32), array([ 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 134, 2, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 180],
dtype=int32), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
dtype=int32)]
answer weights: [array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), array([1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1.], dtype=float32), array([1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1.,
0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1.,
0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1.], dtype=float32), array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
- Analysis
(1) 对问题和答案的向量重新整理:符合桶尺寸则保持对话尺寸;若不符合桶设定尺寸,则进行填充处理,问题使用PAD_ID填充,答案使用GO_ID和PAD_ID填充;
(2) 对问题和答案向量填充整理后,使用Attention机制对答案进行权重分配:答案中的PAD_ID权重为0,其他位置为1;
(3) get_batch()处理词向量后,返回问题、答案和答案权重数据,即上面结果中的encoder_inputs、decoder_inputs、answer_weights。
1.2 处理问答及答案权重
- Demo
def step(self, session, encoder_inputs, decoder_inputs, answer_weights,
         bucket_id, forward_only):
    """Run a step of the model feeding the given inputs.

    Args:
        session: tensorflow session.
        encoder_inputs: list of question vectors (time-major).
        decoder_inputs: list of answer vectors (time-major).
        answer_weights: list of answer target weights (time-major).
        bucket_id: which bucket of the model to use.
        forward_only: True for the forward (inference) pass only; False to
            also run the backward pass and update parameters.

    Returns:
        A triple (gradient norm, loss, outputs): the gradient norm is None
        when forward_only is True, and the output logits are None when it
        is False.

    Raises:
        ValueError: if length of encoder_inputs, decoder_inputs, or
            answer_weights disagrees with bucket size for the specified
            bucket_id.
    """
    # The batch must exactly match the bucket's encoder/decoder sizes.
    encoder_size, decoder_size = self.buckets[bucket_id]
    if len(encoder_inputs) != encoder_size:
        raise ValueError("Encoder length must be equal to the one in bucket,"
                         " %d != %d." % (len(encoder_inputs), encoder_size))
    if len(decoder_inputs) != decoder_size:
        raise ValueError("Decoder length must be equal to the one in bucket,"
                         " %d != %d." % (len(decoder_inputs), decoder_size))
    if len(answer_weights) != decoder_size:
        raise ValueError("Weights length must be equal to the one in bucket,"
                         " %d != %d." % (len(answer_weights), decoder_size))
    # Input feed: encoder inputs, decoder inputs, answer_weights, as provided.
    # NOTE(review): the Seq2SeqModel class below stores the weight
    # placeholders as self.target_weights, not self.answer_weights — confirm
    # the attribute name against the actual seq2seq_model module.
    input_feed = {}
    for l in xrange(encoder_size):
        input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
    for l in xrange(decoder_size):
        input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
        input_feed[self.answer_weights[l].name] = answer_weights[l]
    # Since our answers are decoder inputs shifted by one, we need one more.
    last_answer = self.decoder_inputs[decoder_size].name
    input_feed[last_answer] = np.zeros([self.batch_size], dtype=np.int32)
    # Output feed: depends on whether we do a backward step or not.
    if not forward_only:
        output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
                       self.gradient_norms[bucket_id],  # Gradient norm.
                       self.losses[bucket_id]]  # Loss for this batch.
    else:
        output_feed = [self.losses[bucket_id]]  # Loss for this batch.
        for l in xrange(decoder_size):  # Output logits.
            output_feed.append(self.outputs[bucket_id][l])
    outputs = session.run(output_feed, input_feed)
    if not forward_only:
        return outputs[1], outputs[2], None  # Gradient norm, loss, no outputs.
    else:
        return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.
- Result
反向传播
gradient norm: 0.14181189239025116, step loss: 0.024454593658447266
前向传播
部分数据
[array([[-8.61980557e-01, 7.83682689e-02, -6.27610302e+00,
4.22952592e-01, -6.75087309e+00, 4.07743502e+00,
6.32381868e+00, 7.75186586e+00, -9.86352444e+00,
...
-9.14800549e+00, 8.39417934e+00, -6.37120008e+00,
-3.18058801e+00, -4.23562908e+00, -3.24618030e+00,
4.26430941e+00, -4.64457321e+00, -2.33847499e+00,
1.18018208e+01, -3.61384606e+00]], dtype=float32), array......
- Analysis
(1) 根据输入的问答向量列表,分配语料桶,处理问答向量列表,并生成新的输入字典(dict),input_feed = {}
;
(2) 输出字典(dict) output_feed:根据是否使用反向传播决定内容——使用反向传播时,output_feed存储更新操作、梯度范数和损失;不使用反向传播时,则只存储损失(及输出logits);
(3) 最终输出分两种情况:使用反向传播(训练)时,返回梯度范数和损失;不使用反向传播(前向,用于加载模型测试效果)时,返回损失和输出的向量。
1.3 seq2seq模型
class Seq2SeqModel(object):
    """Bucketed attention seq2seq model with sampled softmax (TF 0.12-era API)."""

    def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
                 num_layers, max_gradient_norm, batch_size, learning_rate,
                 learning_rate_decay_factor, use_lstm=False,
                 num_samples=512, forward_only=False):
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.buckets = buckets
        self.batch_size = batch_size
        # Learning rate is a variable so it can be decayed through an op.
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)
        output_projection = None
        softmax_loss_function = None
        # Sampled softmax only makes sense when sampling fewer classes than
        # the full target vocabulary.
        if num_samples > 0 and num_samples < self.target_vocab_size:
            w = tf.get_variable("proj_w", [size, self.target_vocab_size])
            w_t = tf.transpose(w)
            b = tf.get_variable("proj_b", [self.target_vocab_size])
            output_projection = (w, b)

            def sampled_loss(labels,logits):
                labels = tf.reshape(labels, [-1, 1])
                return tf.nn.sampled_softmax_loss(w_t, b, labels,logits,num_samples,self.target_vocab_size)
            softmax_loss_function = sampled_loss
        # RNN cell: GRU by default, LSTM on request; stacked when num_layers > 1.
        single_cell = tf.contrib.rnn.GRUCell(size)
        if use_lstm:
            single_cell = tf.contrib.rnn.BasicLSTMCell(size)
        cell = single_cell
        if num_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)

        def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
            # deepcopy so each bucket builds on its own cell instance.
            # NOTE(review): requires `import copy` at module level.
            tem_cell=copy.deepcopy(cell)
            return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
                encoder_inputs, decoder_inputs, tem_cell,
                num_encoder_symbols=source_vocab_size,
                num_decoder_symbols=target_vocab_size,
                embedding_size=size,
                output_projection=output_projection,
                feed_previous=do_decode)
        # Feed placeholders sized for the largest bucket.
        self.encoder_inputs = []
        self.decoder_inputs = []
        self.target_weights = []
        for i in xrange(buckets[-1][0]):  # Last bucket is the biggest one.
            self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                      name="encoder{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                      name="decoder{0}".format(i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                      name="weight{0}".format(i)))
        # Targets are decoder inputs shifted left by one (drop the GO symbol).
        targets = [self.decoder_inputs[i + 1]
                   for i in xrange(len(self.decoder_inputs) - 1)]
        # Training outputs and losses.
        if forward_only:
            self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
                self.encoder_inputs, self.decoder_inputs, targets,
                self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
                softmax_loss_function=softmax_loss_function)
            # Project RNN outputs back to full-vocabulary logits for decoding.
            if output_projection is not None:
                for b in xrange(len(buckets)):
                    self.outputs[b] = [
                        tf.matmul(output, output_projection[0]) + output_projection[1]
                        for output in self.outputs[b]
                    ]
        else:
            self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
                self.encoder_inputs, self.decoder_inputs, targets,
                self.target_weights, buckets,
                lambda x, y: seq2seq_f(x, y, False),
                softmax_loss_function=softmax_loss_function)
        # Gradient clipping + SGD update ops, one set per bucket.
        params = tf.trainable_variables()
        if not forward_only:
            self.gradient_norms = []
            self.updates = []
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            for b in xrange(len(buckets)):
                gradients = tf.gradients(self.losses[b], params)
                clipped_gradients, norm = tf.clip_by_global_norm(gradients,
                                                                 max_gradient_norm)
                self.gradient_norms.append(norm)
                self.updates.append(opt.apply_gradients(
                    zip(clipped_gradients, params), global_step=self.global_step))
        self.saver = tf.train.Saver(tf.all_variables())
2 训练&保存模型
# FIX: Seq2SeqModel.__init__ declares source_vocab_size / target_vocab_size;
# the original call passed question_vocab_size / answer_vocab_size, which
# raises TypeError.
model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97, forward_only=False)
# 使用GPU配置
# # config = tf.ConfigProto()
# # config.gpu_options.allocator_type = 'BFC' # 防止 out of memory
if __name__ == "__main__":
    with tf.Session() as sess:
        # with tf.Session(config=config) as sess:
        # Resume from an existing checkpoint when one is found; otherwise
        # initialize all variables and train from scratch.
        ckpt = tf.train.get_checkpoint_state('./models')
        if ckpt != None:
            # NOTE(review): assumes checkpoint paths look like "...-<step>".
            train_turn = ckpt.model_checkpoint_path.split('-')[1]
            print("model path: {}, train turns: {}".format(ckpt.model_checkpoint_path, train_turn))
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            total_step = int(train_turn)
        else:
            sess.run(tf.global_variables_initializer())
            total_step = 0
        train_set = read_data(train_encode_vec, train_decode_vec)
        test_set = read_data(test_encode_vec, test_decode_vec)
        train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
        # print("train bucket sizes: {}".format(train_bucket_sizes))
        train_total_size = float(sum(train_bucket_sizes))
        # Cumulative bucket shares, used to pick a random bucket below.
        train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
        loss = 0.0
        # total_step = int(train_turn)
        previous_losses = []
        # Train indefinitely, checkpointing and evaluating every 500 steps.
        while True:
            random_number_01 = np.random.random_sample()
            # Smallest bucket index whose cumulative share exceeds the draw,
            # so buckets are sampled proportionally to their size.
            bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
            encoder_inputs, decoder_inputs, answer_weights = model.get_batch(train_set, bucket_id)
            # print("encoder inputs: {}, decoder inputs: {}, answer weights: {}".format(encoder_inputs, decoder_inputs, answer_weights))
            gradient_norm, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, answer_weights, bucket_id, False)
            print("gradient norm: {}, step loss: {}".format(gradient_norm, step_loss))
            # Average the loss over the 500-step reporting window.
            loss += step_loss / 500
            total_step += 1
            print("total step: {}".format(total_step))
            if total_step % 500 == 0:
                print("global step: {}, learning rate: {}, loss: {}".format(model.global_step.eval(), model.learning_rate.eval(), loss))
                # Decay the learning rate when the loss stopped improving.
                if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save a checkpoint.
                checkpoint_path = "./models/chatbot_seq2seq.ckpt"
                model.saver.save(sess, checkpoint_path, global_step=model.global_step)
                loss = 0.0
                # Evaluate perplexity on the test set, bucket by bucket.
                for bucket_id in range(len(buckets)):
                    if len(test_set[bucket_id]) == 0:
                        continue
                    encoder_inputs, decoder_inputs, answer_weights = model.get_batch(test_set, bucket_id)
                    _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, answer_weights, bucket_id, True)
                    eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
                    print("bucket id: {}, eval ppx: {}".format(bucket_id, eval_ppx))
- Result
total step: 1
total step: 2
total step: 3
total step: 4
total step: 5
...
global step: 500, learning rate: 0.5, loss: 2.574068747580052
bucket id: 0, eval ppx: 14176.588030763274
bucket id: 1, eval ppx: 3650.0026667220773
bucket id: 2, eval ppx: 4458.454110999805
bucket id: 3, eval ppx: 5290.083583183104
3 载入模型&测试
- Demo
import tensorflow as tf # 0.12
import seq2seq_model
import os
import numpy as np
# Special token ids shared with the training pipeline.
PAD_ID = 0  # padding
GO_ID = 1  # decoder start-of-sequence
EOS_ID = 2  # end-of-sequence
UNK_ID = 3  # out-of-vocabulary token
# One token per line; line number == token id.
train_encode_vocabulary = './data/word2vec/train_question_encode_vocabulary'
train_decode_vocabulary = './data/word2vec/train_answer_decode_vocabulary'
# Vectorized corpora are not needed at inference time.
# train_encode_vec = './data/word2vec/train_question_encode.vec'
# train_decode_vec = './data/word2vec/train_answer_decode.vec'
# test_encode_vec = './data/word2vec/test_question_encode.vec'
# test_decode_vec = './data/word2vec/test_answer_decode.vec'
def read_vocabulary(input_file):
    """Load a vocabulary file with one token per line.

    Returns a pair ``(vocab, tokens)`` where ``vocab`` maps each token to
    its line index and ``tokens`` is the token list in file order.
    """
    with open(input_file, "r") as f:
        tokens = [line.strip() for line in f]
    # Invert the positional list into a token -> index lookup table.
    lookup = {token: index for index, token in enumerate(tokens)}
    return lookup, tokens
vocab_en, _, = read_vocabulary(train_encode_vocabulary)  # question token -> id
_, vocab_de, = read_vocabulary(train_decode_vocabulary)  # id -> answer token
# NOTE(review): the original comment here claimed a vocabulary size of 5000,
# but the constants below are 470 — presumably the comment was stale; verify
# against the vocabulary files.
vocabulary_encode_size = 470
vocabulary_decode_size = 470
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256  # units per layer
num_layers = 3  # number of layers
batch_size = 1
# Earlier constructor signature, kept for reference:
# model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
# buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm= 5.0,
# batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)
# forward_only=True: inference only, no backward pass is built.
model = seq2seq_model.Seq2SeqModel(question_vocab_size=vocabulary_encode_size, answer_vocab_size=vocabulary_decode_size,
buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm= 5.0,
batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)
# Decode one sentence at a time.
model.batch_size = 1
with tf.Session() as sess:
    # Restore the most recent checkpoint, if any.
    ckpt = tf.train.get_checkpoint_state('./models')
    if ckpt is not None:  # fix: compare to None with `is`, not `!=` (PEP 8)
        print(ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # NOTE(review): the chat loop below still runs with untrained
        # weights when no checkpoint is found — confirm this is intended.
        print("没找到模型")
    # Interactive loop: read a question, run the model, print the reply.
    while True:
        input_string = input('me > ')
        # Type 'quit' to exit.
        if input_string == 'quit':
            exit()
        # Character-level encoding: map each character to its vocabulary id,
        # falling back to UNK_ID for unknown characters.
        input_string_vec = [vocab_en.get(ch, UNK_ID) for ch in input_string.strip()]
        # Fix: an input at least as long as the largest bucket's question
        # size (40) used to crash with `min() arg is an empty sequence`;
        # truncate so that some bucket always fits.
        input_string_vec = input_string_vec[:buckets[-1][0] - 1]
        # Smallest bucket whose question side can hold the input.
        bucket_id = min(b for b in range(len(buckets))
                        if buckets[b][0] > len(input_string_vec))
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(input_string_vec, [])]}, bucket_id)
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        print("output logits: {}".format(output_logits))
        # Greedy decoding: take the arg-max token at every step, stop at EOS.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        if EOS_ID in outputs:
            outputs = outputs[:outputs.index(EOS_ID)]
        response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])
        print('AI > ' + response)
- Result
4 总结
(1) 该聊天机器人使用bucket
桶结构,即指定问答数据的长度,匹配符合的桶,在桶中进行存取数据;
(2) 该seq2seq模型使用Tensorflow
时,未能建立独立标识的图结构,在进行后台封装过程中出现图为空的现象;
[参考文献]
[1]http://blog.topspeedsnail.com/archives/10735/comment-page-1#comment-1161
[2]https://github.com/tensorflow/nmt
[3]https://github.com/tensorflow/tensorflow/blob/b19d6657070bbf1df5706195a0bf3a92cbf371fc/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
[4]https://github.com/tensorflow/tensorflow/tree/b19d6657070bbf1df5706195a0bf3a92cbf371fc/tensorflow/contrib/seq2seq/python/ops