# [置顶] 如何用TensorFlow训练聊天机器人（附github）

12209人阅读 评论(13)

## 数据预处理

# Sequence-length limits: questions must be 0-10 tokens, answers 3-8 tokens.
limit = {
    'maxq': 10,
    'minq': 0,
    'maxa': 8,
    'mina': 3,
}

UNK = 'unk'        # placeholder token for out-of-vocabulary words
GO = '<go>'        # decoder start-of-sequence token
EOS = '<eos>'      # decoder end-of-sequence token
VOCAB_SIZE = 1000  # number of most-frequent words kept in the vocabulary


def filter_data(sequences):
    """Filter alternating question/answer lines by token-count limits.

    Args:
        sequences: flat list where each even index holds a question and the
            following odd index holds the matching answer.

    Returns:
        (filtered_q, filtered_a): parallel lists of the Q/A pairs whose
        whitespace-token counts fall inside ``limit``.
    """
    filtered_q, filtered_a = [], []
    raw_data_len = len(sequences) // 2

    for i in range(0, len(sequences), 2):
        qlen = len(sequences[i].split(' '))
        alen = len(sequences[i + 1].split(' '))
        if limit['minq'] <= qlen <= limit['maxq'] and \
                limit['mina'] <= alen <= limit['maxa']:
            filtered_q.append(sequences[i])
            filtered_a.append(sequences[i + 1])

    # BUG FIX: guard against division by zero — the original raised
    # ZeroDivisionError when `sequences` was empty.
    if raw_data_len:
        filtered = int((raw_data_len - len(filtered_q)) * 100 / raw_data_len)
    else:
        filtered = 0
    print(str(filtered) + '% filtered from original data')

    return filtered_q, filtered_a

def index_(tokenized_sentences, vocab_size):
    """Build vocabulary structures from tokenized sentences.

    Returns (index2word, word2index, freq_dist): index2word places the
    special tokens GO/EOS/UNK/PAD ahead of the vocab_size most common
    words; word2index is the inverse mapping.
    """
    # Count word frequencies over the flattened token stream.
    freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences))
    vocab = freq_dist.most_common(vocab_size)
    # NOTE(review): PAD is not defined in this excerpt — presumably a
    # module-level constant alongside UNK/GO/EOS; confirm it exists upstream.
    index2word = [GO, EOS, UNK, PAD] + [pair[0] for pair in vocab]
    word2index = {word: idx for idx, word in enumerate(index2word)}
    return index2word, word2index, freq_dist

def zero_pad(qtokenized, atokenized, w2idx):
    """Convert tokenized Q/A pairs into fixed-width index matrices.

    Returns (idx_q, idx_a, idx_o): encoder inputs, decoder inputs and
    decoder targets as int32 arrays. The answer matrices are maxa + 2
    wide because pad_seq inserts the '<go>'/'<eos>' markers.
    """
    data_len = len(qtokenized)
    idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32)
    # +2 accounts for the '<go>' and '<eos>' tokens added by pad_seq.
    idx_a = np.zeros([data_len, limit['maxa'] + 2], dtype=np.int32)
    idx_o = np.zeros([data_len, limit['maxa'] + 2], dtype=np.int32)

    for row, (q_tokens, a_tokens) in enumerate(zip(qtokenized, atokenized)):
        idx_q[row] = np.array(pad_seq(q_tokens, w2idx, limit['maxq'], 1))
        idx_a[row] = np.array(pad_seq(a_tokens, w2idx, limit['maxa'], 2))
        idx_o[row] = np.array(pad_seq(a_tokens, w2idx, limit['maxa'], 3))

    return idx_q, idx_a, idx_o

# NOTE(review): the `def` line of this function was lost in the paste; the
# signature below is reconstructed from its call sites in zero_pad:
# pad_seq(tokens, w2idx, limit['maxq'|'maxa'], flag).
def pad_seq(seq, lookup, maxlen, flag):
    """Map tokens to vocabulary ids and pad to a fixed width.

    flag 1: encoder input  -> ids + PAD             (width maxlen)
    flag 2: decoder input  -> GO + ids + EOS + PAD  (width maxlen + 2)
    flag 3: decoder target -> ids + EOS + PAD       (width maxlen + 2)
    Words missing from `lookup` map to the UNK id.
    """
    if flag == 1:
        indices = []
    elif flag == 2:
        indices = [lookup[GO]]
    elif flag == 3:
        indices = []
    for word in seq:
        if word in lookup:
            indices.append(lookup[word])
        else:
            indices.append(lookup[UNK])
    if flag == 1:
        return indices + [lookup[PAD]] * (maxlen - len(seq))
    elif flag == 2:
        return indices + [lookup[EOS]] + [lookup[PAD]] * (maxlen - len(seq))
    elif flag == 3:
        # +1 pad keeps targets the same width as decoder inputs, which
        # carry an extra leading GO token.
        return indices + [lookup[EOS]] + [lookup[PAD]] * (maxlen - len(seq) + 1)

## 构建图

# Training-graph placeholders: token ids for encoder/decoder, integer targets
# and per-position loss weights. batch_size / sequence_length / hidden_size /
# num_layers etc. are hyperparameters defined elsewhere in the file.
encoder_inputs = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])
decoder_inputs = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])
targets = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])
weights = tf.placeholder(dtype=tf.float32, shape=[batch_size, sequence_length])

# BUG FIX: the original built `[cell] * num_layers`, which hands the SAME
# BasicLSTMCell object to every layer of the MultiRNNCell — layers end up
# sharing weights (and newer TF 1.x versions raise a ValueError for it).
# Each layer must get its own cell instance.
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(hidden_size) for _ in range(num_layers)])

results, states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
    tf.unstack(encoder_inputs, axis=1),   # list of `sequence_length` tensors
    tf.unstack(decoder_inputs, axis=1),
    cell,
    num_encoder_symbols,
    num_decoder_symbols,
    embedding_size,
    feed_previous=False  # teacher forcing during training
)

logits = tf.stack(results, axis=1)  # [batch, time, num_decoder_symbols]
loss = tf.contrib.seq2seq.sequence_loss(logits, targets=targets, weights=weights)
pred = tf.argmax(logits, axis=2)
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

## 创建会话

with tf.Session() as sess:
    # Resume from the latest checkpoint if one exists, otherwise initialize.
    # NOTE(review): `saver` must be a tf.train.Saver() created before this
    # block — its definition is not visible in this excerpt; confirm upstream.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    epoch = 0
    while epoch < 5000000:
        epoch += 1
        print("epoch:", epoch)
        for step in range(0, 1):
            print("step:", step)
            lo = step * batch_size
            hi = lo + batch_size
            feed = {
                encoder_inputs: train_x[lo:hi, :],
                decoder_inputs: train_y[lo:hi, :],
                targets: train_target[lo:hi, :],
                weights: train_weights,
            }
            # BUG FIX: fetch the optimizer op and the loss in ONE sess.run —
            # the original ran a second sess.run(loss, ...) that recomputed
            # the entire forward pass just to print the cost.
            _, cost = sess.run([train_op, loss], feed_dict=feed)
            print(cost)
            # (removed dead `step = step + 1` — the for-loop controls `step`)
        if epoch % 100 == 0:
            saver.save(sess, model_dir + '/model.ckpt', global_step=epoch + 1)

## 预测

# Build the inference graph on CPU with batch size 1; feed_previous=True makes
# the decoder consume its own previous prediction at each time step.
with tf.device('/cpu:0'):
    batch_size = 1
    sequence_length = 10
    num_encoder_symbols = 1004  # VOCAB_SIZE (1000) + 4 special tokens
    num_decoder_symbols = 1004
    embedding_size = 256
    hidden_size = 256
    num_layers = 2

    encoder_inputs = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])
    decoder_inputs = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])

    targets = tf.placeholder(dtype=tf.int32, shape=[batch_size, sequence_length])
    weights = tf.placeholder(dtype=tf.float32, shape=[batch_size, sequence_length])

    # BUG FIX: create a distinct LSTM cell per layer instead of the original
    # `[cell] * num_layers`, which shares one cell object across all layers.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(hidden_size) for _ in range(num_layers)])

    results, states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
        tf.unstack(encoder_inputs, axis=1),
        tf.unstack(decoder_inputs, axis=1),
        cell,
        num_encoder_symbols,
        num_decoder_symbols,
        embedding_size,
        feed_previous=True,  # inference: feed back previous decoder outputs
    )
    logits = tf.stack(results, axis=1)
    pred = tf.argmax(logits, axis=2)

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the most recent training checkpoint.
    module_file = tf.train.latest_checkpoint('./model/')
    saver.restore(sess, module_file)
    map = Word_Id_Map()

    # NOTE(review): `encoder_input` is read before any visible assignment —
    # the line converting the input sentence to ids (via Word_Id_Map) appears
    # to have been lost from this excerpt; confirm upstream.
    # Pad the id sequence to length 10 with token id 3.
    encoder_input = encoder_input + [3] * (10 - len(encoder_input))
    encoder_input = np.asarray([np.asarray(encoder_input)])
    decoder_input = np.zeros([1, 10])
    print('encoder_input : ', encoder_input)
    print('decoder_input : ', decoder_input)
    pred_value = sess.run(pred, feed_dict={encoder_inputs: encoder_input,
                                           decoder_inputs: decoder_input})
    print(pred_value)
    sentence = map.ids2sentence(pred_value[0])
    print(sentence)

['how', 'do', 'you', 'do', 'this', '', '', '', '', '']

## github

https://github.com/sea-boat/seq2seq_chatbot.git

========广告时间========

=========================

《LSTM神经网络》

《循环神经网络》

《深度学习的seq2seq模型》

《机器学习之神经网络》

《GRU神经网络》

10
0

作者
https://github.com/sea-boat

公众号：（内容包括分布式、机器学习、深度学习、NLP、Java深度、Java并发核心、JDK源码、Tomcat内核等等）

微信：

打赏作者

如果您觉得作者写的文章有帮助到您，您可以打赏作者一瓶汽水(*^__^*)

个人资料
• 访问：1062039次
• 积分：14069
• 等级：
• 排名：第1049名
• 原创：326篇
• 转载：5篇
• 译文：1篇
• 评论：348条
博客专栏
 JDK源码 文章：35篇 阅读：99016
 机器学习&深度学习 文章：38篇 阅读：81760
 自然语言处理 文章：13篇 阅读：34473
 mysql协议 文章：20篇 阅读：30704
 Hazelcast 文章：5篇 阅读：27533
 通信框架Tribes 文章：8篇 阅读：22070
 集群 文章：16篇 阅读：62603
 tomcat内核 文章：83篇 阅读：306069
 Java并发 文章：22篇 阅读：70340
 java开源研究 文章：40篇 阅读：164859
最新评论