This example is based on a tutorial video and a Zhihu article; after downloading the author's GitHub code I could no longer find the article's address, my apologies. I typed out the code myself, ran it, added some comments, and wrote it up as a blog post to deepen my understanding.
The goal of this example is to train a seq2seq model that takes an input string whose characters are in arbitrary order and outputs them sorted alphabetically, e.g. "open" -> "enop". Of course a trivial algorithm can sort a string directly; the point here is to understand how seq2seq works, and a simple task makes that easier.
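For reference, the non-learning baseline is just Python's built-in sort; everything that follows only shows how a seq2seq model learns to reproduce it:

print(''.join(sorted('open')))   # enop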
Check the TensorFlow version
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow.python.layers.core import Dense
# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))
Load the data
import numpy as np
import time
import tensorflow as tf
with open('data/letters_source.txt', 'r', encoding='utf-8') as f:
    source_data = f.read()
with open('data/letters_target.txt', 'r', encoding='utf-8') as f:
    target_data = f.read()
# preview the data
source_data.split('\n')[:10]
# preview the data
target_data.split('\n')[:10]
Data preprocessing
def extract_character_vocab(data):
    '''Build the lookup tables between characters and integer ids'''
    special_words = ['<PAD>', '<UNK>', '<GO>', '<EOS>']
    # set() builds an unordered collection without duplicates
    set_words = list(set([character for line in data.split('\n') for character in line]))
    # the four special tokens also go into the vocabulary
    int_to_vocab = {idx: word for idx, word in enumerate(special_words + set_words)}
    vocab_to_int = {word: idx for idx, word in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int

# build the lookup tables
source_int_to_letter, source_letter_to_int = extract_character_vocab(source_data)
target_int_to_letter, target_letter_to_int = extract_character_vocab(target_data)

# convert each letter to its id; get() returns the index of letter, falling back to
# source_letter_to_int['<UNK>'] when the letter is not in the vocabulary
source_int = [[source_letter_to_int.get(letter, source_letter_to_int['<UNK>'])
               for letter in line] for line in source_data.split('\n')]
target_int = [[target_letter_to_int.get(letter, target_letter_to_int['<UNK>'])
               for letter in line] + [target_letter_to_int['<EOS>']] for line in target_data.split('\n')]

# inspect the converted ids
source_int[:10]
target_int[:10]
Build the model
Input layer
def get_inputs():
    '''Create the model's input tensors (6 tensors in total)'''
    inputs = tf.placeholder(tf.int32, [None, None], name='inputs')      # encoder input
    targets = tf.placeholder(tf.int32, [None, None], name='targets')    # decoder targets
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')    # learning rate
    # target sequence lengths (target_sequence_length and source_sequence_length
    # will later be supplied through feed_dict)
    target_sequence_length = tf.placeholder(tf.int32, (None,), name='target_sequence_length')  # per-example target length
    max_target_sequence_length = tf.reduce_max(target_sequence_length, name='max_target_len')   # longest target in the batch
    source_sequence_length = tf.placeholder(tf.int32, (None,), name='source_sequence_length')  # per-example source length
    return inputs, targets, learning_rate, target_sequence_length, max_target_sequence_length, source_sequence_length
Encoder
On the Encoder side there are two steps: first we embed the input, then we feed the embedded vectors into an RNN.
Let's look at an example. Suppose we have a batch of 2 sequences of length 5, features = [[1,2,3,4,5],[6,7,8,9,10]]. Calling tf.contrib.layers.embed_sequence(features, vocab_size=n_words, embed_dim=10) returns a 2 x 5 x 10 tensor in which every id in features has been embedded into a 10-dimensional vector.
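A minimal runnable sketch of that call (the vocab_size of 11 is just an assumed toy value, large enough to cover the ids above):

import tensorflow as tf

tf.reset_default_graph()
features = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])          # batch=2, sequence_length=5
embedded = tf.contrib.layers.embed_sequence(features, vocab_size=11, embed_dim=10)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())                       # the embedding matrix is a variable
    print(sess.run(embedded).shape)                                   # (2, 5, 10)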
def get_encoder_layer(input_data, rnn_size, num_layers,
                      source_sequence_length, source_vocab_size,
                      encoding_embedding_size):
    '''
    Build the Encoder layer

    Arguments:
    - input_data: input tensor
    - rnn_size: number of hidden units per RNN cell
    - num_layers: number of stacked RNN cells
    - source_sequence_length: sequence lengths of the source data
    - source_vocab_size: vocabulary size of the source data
    - encoding_embedding_size: embedding dimension
    '''
    # Encoder embedding
    encoder_embed_input = tf.contrib.layers.embed_sequence(input_data, source_vocab_size, encoding_embedding_size)

    # RNN cell
    def get_lstm_cell(rnn_size):
        lstm_cell = tf.contrib.rnn.LSTMCell(rnn_size, initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        return lstm_cell

    cell = tf.contrib.rnn.MultiRNNCell([get_lstm_cell(rnn_size) for _ in range(num_layers)])

    encoder_output, encoder_state = tf.nn.dynamic_rnn(cell, encoder_embed_input,
                                                      sequence_length=source_sequence_length, dtype=tf.float32)
    return encoder_output, encoder_state
Decoder
Preprocess the target data
def process_decoder_input(data, vocab_to_int, batch_size):
    '''Prepend <GO> and drop the last character of each target sequence'''
    # cut off the last character
    ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
    decoder_input = tf.concat([tf.fill([batch_size, 1], vocab_to_int['<GO>']), ending], 1)  # concatenate along axis 1 (the time dimension)
    return decoder_input
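A quick sanity check of this transformation with toy values (the ids and the <GO> index of 2 below are made up purely for illustration):

import tensorflow as tf

toy_vocab_to_int = {'<GO>': 2}                        # assumed toy mapping
toy_targets = tf.constant([[5, 6, 7, 3],              # e.g. b c d <EOS>
                           [8, 9, 4, 3]])             # e.g. e f a <EOS>
toy_decoder_input = process_decoder_input(toy_targets, toy_vocab_to_int, 2)
with tf.Session() as sess:
    print(sess.run(toy_decoder_input))
    # [[2 5 6 7]
    #  [2 8 9 4]]  -> last column (<EOS>) dropped, <GO> prepended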
Embed the data【*****】
Likewise, we need to embed the target data so that it can be fed into the RNN inside the Decoder.
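The Decoder embedding below uses tf.nn.embedding_lookup, which picks row i of the embedding matrix for every id i. A tiny sketch with assumed toy sizes:

import tensorflow as tf

toy_embeddings = tf.Variable(tf.random_uniform([30, 15]))   # assumed: vocab of 30, embedding size 15
toy_ids = tf.constant([[2, 5, 6, 7]])                        # batch=1, time=4
toy_embedded = tf.nn.embedding_lookup(toy_embeddings, toy_ids)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(toy_embedded).shape)                      # (1, 4, 15)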
def decoding_layer(target_letter_to_int, decoding_embedding_size, num_layers, rnn_size,
                   target_sequence_length, max_target_sequence_length, encoder_state, decoder_input):
    '''
    Build the Decoder layer

    Arguments:
    - target_letter_to_int: lookup table for the target data
    - decoding_embedding_size: embedding dimension
    - num_layers: number of stacked RNN cells
    - rnn_size: number of hidden units per RNN cell
    - target_sequence_length: sequence lengths of the target data
    - max_target_sequence_length: longest target sequence
    - encoder_state: state vector produced by the encoder
    - decoder_input: decoder-side input
    '''
    # 1. Embedding
    target_vocab_size = len(target_letter_to_int)  # size of the target vocabulary
    decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)  # pick the embedding row for each id

    # 2. RNN cells for the Decoder
    def get_decoder_cell(rnn_size):
        decoder_cell = tf.contrib.rnn.LSTMCell(rnn_size,
                                               initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        return decoder_cell

    cell = tf.contrib.rnn.MultiRNNCell([get_decoder_cell(rnn_size) for _ in range(num_layers)])

    # 3. Fully connected output layer
    output_layer = Dense(target_vocab_size,
                         kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    # 4. Training decoder
    with tf.variable_scope("decode"):
        # helper object
        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
                                                            sequence_length=target_sequence_length,
                                                            time_major=False)
        # build the decoder
        training_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                           training_helper,
                                                           encoder_state,
                                                           output_layer)
        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                          impute_finished=True,
                                                                          maximum_iterations=max_target_sequence_length)
    # 5. Predicting decoder
    # shares parameters with the training decoder
    with tf.variable_scope("decode", reuse=True):
        # create a constant tensor and tile it batch_size times
        # (batch_size is the global hyperparameter defined further down)
        start_tokens = tf.tile(tf.constant([target_letter_to_int['<GO>']], dtype=tf.int32), [batch_size],
                               name='start_tokens')
        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,
                                                                     start_tokens,
                                                                     target_letter_to_int['<EOS>'])
        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                             predicting_helper,
                                                             encoder_state,
                                                             output_layer)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,
                                                                            impute_finished=True,
                                                                            maximum_iterations=max_target_sequence_length)
    return training_decoder_output, predicting_decoder_output
Seq2Seq【*****】
With the Encoder and Decoder built above, we now connect the two parts into the complete seq2seq model.
def seq2seq_model(input_data, targets, lr, target_sequence_length,
                  max_target_sequence_length, source_sequence_length,
                  source_vocab_size, target_vocab_size,
                  encoder_embedding_size, decoder_embedding_size,
                  rnn_size, num_layers):
    # get the encoder's state output
    _, encoder_state = get_encoder_layer(input_data,
                                         rnn_size,
                                         num_layers,
                                         source_sequence_length,
                                         source_vocab_size,
                                         encoder_embedding_size)
    # preprocessed decoder input
    decoder_input = process_decoder_input(targets, target_letter_to_int, batch_size)
    # pass the state vector and the decoder input to the decoder
    training_decoder_output, predicting_decoder_output = decoding_layer(target_letter_to_int,
                                                                        decoder_embedding_size,
                                                                        num_layers,
                                                                        rnn_size,
                                                                        target_sequence_length,
                                                                        max_target_sequence_length,
                                                                        encoder_state,
                                                                        decoder_input)
    return training_decoder_output, predicting_decoder_output
Set the hyperparameters
# Hyperparameters
# Number of Epochs
epochs = 60
# Batch Size
batch_size = 128
# RNN Size
rnn_size = 50
# Number of Layers
num_layers = 2
# Embedding Size
encoding_embedding_size = 15
decoding_embedding_size = 15
# Learning Rate
learning_rate = 0.001
Build the graph
# build the graph
train_graph = tf.Graph()
with train_graph.as_default():
    # get the model inputs
    input_data, targets, lr, target_sequence_length, max_target_sequence_length, source_sequence_length = get_inputs()

    training_decoder_output, predicting_decoder_output = seq2seq_model(input_data, targets, lr, target_sequence_length,
                                                                       max_target_sequence_length, source_sequence_length,
                                                                       len(source_letter_to_int), len(target_letter_to_int),
                                                                       encoding_embedding_size, decoding_embedding_size,
                                                                       rnn_size, num_layers)
    # tf.identity returns a tensor with the same shape and values as its input, similar to y = x
    training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')
    predicting_logits = tf.identity(predicting_decoder_output.sample_id, name='predictions')

    masks = tf.sequence_mask(target_sequence_length, max_target_sequence_length, dtype=tf.float32, name='masks')

    with tf.name_scope("optimization"):
        # Loss function
        cost = tf.contrib.seq2seq.sequence_loss(
            training_logits,
            targets,
            masks)
        # Optimizer
        optimizer = tf.train.AdamOptimizer(lr)
        # Gradient Clipping
        gradients = optimizer.compute_gradients(cost)
        capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients)
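As an aside on the masks used in the loss above, here is a minimal sketch (with toy lengths) of what tf.sequence_mask produces: positions beyond a sequence's real length get weight 0, so <PAD> tokens do not contribute to the loss.

import tensorflow as tf

toy_mask = tf.sequence_mask([5, 3], maxlen=5, dtype=tf.float32)   # two sequences of length 5 and 3
with tf.Session() as sess:
    print(sess.run(toy_mask))
    # [[1. 1. 1. 1. 1.]
    #  [1. 1. 1. 0. 0.]]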
Set up the batches
def pad_sentence_batch(sentence_batch, pad_int):
    '''
    Pad every sequence in a batch so that all rows share the same sequence_length

    Arguments:
    - sentence_batch
    - pad_int: id of the <PAD> token
    '''
    max_sentence = max([len(sentence) for sentence in sentence_batch])
    return [sentence + [pad_int] * (max_sentence - len(sentence)) for sentence in sentence_batch]

def get_batches(targets, sources, batch_size, source_pad_int, target_pad_int):
    '''
    Generator that yields one batch at a time
    '''
    for batch_i in range(0, len(sources) // batch_size):  # // is integer division
        start_i = batch_i * batch_size
        sources_batch = sources[start_i:start_i + batch_size]
        targets_batch = targets[start_i:start_i + batch_size]
        # pad the sequences
        pad_sources_batch = np.array(pad_sentence_batch(sources_batch, source_pad_int))
        pad_targets_batch = np.array(pad_sentence_batch(targets_batch, target_pad_int))

        # record the length of every sequence
        targets_lengths = []
        for target in targets_batch:
            targets_lengths.append(len(target))  # length of each target sequence in the batch
        source_lengths = []
        for source in sources_batch:
            source_lengths.append(len(source))  # length of each source sequence in the batch
        # yield works like return, but the generator resumes here on the next iteration
        yield pad_targets_batch, pad_sources_batch, targets_lengths, source_lengths
Training
# split the data into train and validation sets
train_source = source_int[batch_size:]
train_target = target_int[batch_size:]
# keep one batch aside for validation
valid_source = source_int[:batch_size]
valid_target = target_int[:batch_size]
(valid_targets_batch, valid_sources_batch, valid_targets_lengths, valid_sources_lengths) = next(get_batches(valid_target, valid_source, batch_size,
                                                                                                            source_letter_to_int['<PAD>'],
                                                                                                            target_letter_to_int['<PAD>']))
display_step = 50  # print the loss every 50 batches

checkpoint = "trained_model.ckpt"
with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_i in range(1, epochs + 1):
        for batch_i, (targets_batch, sources_batch, targets_lengths, sources_lengths) in enumerate(
                get_batches(train_target, train_source, batch_size,
                            source_letter_to_int['<PAD>'],
                            target_letter_to_int['<PAD>'])):
            _, loss = sess.run(
                [train_op, cost],
                {input_data: sources_batch,
                 targets: targets_batch,
                 lr: learning_rate,
                 target_sequence_length: targets_lengths,
                 source_sequence_length: sources_lengths})

            if batch_i % display_step == 0:
                # compute the validation loss
                validation_loss = sess.run(
                    [cost],
                    {input_data: valid_sources_batch,
                     targets: valid_targets_batch,
                     lr: learning_rate,
                     target_sequence_length: valid_targets_lengths,
                     source_sequence_length: valid_sources_lengths})

                print('Epoch {:>3}/{} Batch {:>4}/{} - Training Loss: {:>6.3f} - Validation loss: {:>6.3f}'
                      .format(epoch_i,
                              epochs,
                              batch_i,
                              len(train_source) // batch_size,
                              loss,
                              validation_loss[0]))

    # save the model
    saver = tf.train.Saver()
    saver.save(sess, checkpoint)
    print('Model Trained and Saved')
Output of the run:
Epoch 1/60 Batch 0/77 - Training Loss: 3.401 - Validation loss: 3.398
Epoch 1/60 Batch 50/77 - Training Loss: 2.781 - Validation loss: 2.785
Epoch 2/60 Batch 0/77 - Training Loss: 2.449 - Validation loss: 2.467
Epoch 2/60 Batch 50/77 - Training Loss: 2.122 - Validation loss: 2.125
Epoch 3/60 Batch 0/77 - Training Loss: 1.982 - Validation loss: 1.989
Epoch 3/60 Batch 50/77 - Training Loss: 1.801 - Validation loss: 1.781
Epoch 4/60 Batch 0/77 - Training Loss: 1.680 - Validation loss: 1.669
Epoch 4/60 Batch 50/77 - Training Loss: 1.527 - Validation loss: 1.521
Epoch 5/60 Batch 0/77 - Training Loss: 1.453 - Validation loss: 1.446
Epoch 5/60 Batch 50/77 - Training Loss: 1.350 - Validation loss: 1.334
Epoch 6/60 Batch 0/77 - Training Loss: 1.286 - Validation loss: 1.277
Epoch 6/60 Batch 50/77 - Training Loss: 1.184 - Validation loss: 1.156
Epoch 7/60 Batch 0/77 - Training Loss: 1.136 - Validation loss: 1.089
Epoch 7/60 Batch 50/77 - Training Loss: 1.018 - Validation loss: 0.982
Epoch 8/60 Batch 0/77 - Training Loss: 0.987 - Validation loss: 0.926
Epoch 8/60 Batch 50/77 - Training Loss: 0.889 - Validation loss: 0.841
Epoch 9/60 Batch 0/77 - Training Loss: 0.864 - Validation loss: 0.799
Epoch 9/60 Batch 50/77 - Training Loss: 0.782 - Validation loss: 0.732
Epoch 10/60 Batch 0/77 - Training Loss: 0.765 - Validation loss: 0.698
Epoch 10/60 Batch 50/77 - Training Loss: 0.697 - Validation loss: 0.646
Epoch 11/60 Batch 0/77 - Training Loss: 0.679 - Validation loss: 0.619
Epoch 11/60 Batch 50/77 - Training Loss: 0.618 - Validation loss: 0.573
Epoch 12/60 Batch 0/77 - Training Loss: 0.595 - Validation loss: 0.543
Epoch 12/60 Batch 50/77 - Training Loss: 0.548 - Validation loss: 0.504
Epoch 13/60 Batch 0/77 - Training Loss: 0.513 - Validation loss: 0.477
Epoch 13/60 Batch 50/77 - Training Loss: 0.478 - Validation loss: 0.441
Epoch 14/60 Batch 0/77 - Training Loss: 0.432 - Validation loss: 0.416
Epoch 14/60 Batch 50/77 - Training Loss: 0.408 - Validation loss: 0.378
Epoch 15/60 Batch 0/77 - Training Loss: 0.359 - Validation loss: 0.360
Epoch 15/60 Batch 50/77 - Training Loss: 0.343 - Validation loss: 0.324
Epoch 16/60 Batch 0/77 - Training Loss: 0.298 - Validation loss: 0.314
Epoch 16/60 Batch 50/77 - Training Loss: 0.291 - Validation loss: 0.277
Epoch 17/60 Batch 0/77 - Training Loss: 0.247 - Validation loss: 0.268
Epoch 17/60 Batch 50/77 - Training Loss: 0.249 - Validation loss: 0.238
Epoch 18/60 Batch 0/77 - Training Loss: 0.208 - Validation loss: 0.227
Epoch 18/60 Batch 50/77 - Training Loss: 0.215 - Validation loss: 0.205
Epoch 19/60 Batch 0/77 - Training Loss: 0.180 - Validation loss: 0.193
Epoch 19/60 Batch 50/77 - Training Loss: 0.184 - Validation loss: 0.177
Epoch 20/60 Batch 0/77 - Training Loss: 0.153 - Validation loss: 0.166
Epoch 20/60 Batch 50/77 - Training Loss: 0.158 - Validation loss: 0.155
Epoch 21/60 Batch 0/77 - Training Loss: 0.131 - Validation loss: 0.146
Epoch 21/60 Batch 50/77 - Training Loss: 0.137 - Validation loss: 0.137
Epoch 22/60 Batch 0/77 - Training Loss: 0.113 - Validation loss: 0.130
Epoch 22/60 Batch 50/77 - Training Loss: 0.119 - Validation loss: 0.122
Epoch 23/60 Batch 0/77 - Training Loss: 0.100 - Validation loss: 0.117
Epoch 23/60 Batch 50/77 - Training Loss: 0.104 - Validation loss: 0.110
Epoch 24/60 Batch 0/77 - Training Loss: 0.089 - Validation loss: 0.106
Epoch 24/60 Batch 50/77 - Training Loss: 0.093 - Validation loss: 0.103
Epoch 25/60 Batch 0/77 - Training Loss: 0.079 - Validation loss: 0.095
Epoch 25/60 Batch 50/77 - Training Loss: 0.084 - Validation loss: 0.098
Epoch 26/60 Batch 0/77 - Training Loss: 0.070 - Validation loss: 0.084
Epoch 26/60 Batch 50/77 - Training Loss: 0.081 - Validation loss: 0.086
Epoch 27/60 Batch 0/77 - Training Loss: 0.064 - Validation loss: 0.078
Epoch 27/60 Batch 50/77 - Training Loss: 0.071 - Validation loss: 0.078
Epoch 28/60 Batch 0/77 - Training Loss: 0.062 - Validation loss: 0.071
Epoch 28/60 Batch 50/77 - Training Loss: 0.061 - Validation loss: 0.071
Epoch 29/60 Batch 0/77 - Training Loss: 0.059 - Validation loss: 0.064
Epoch 29/60 Batch 50/77 - Training Loss: 0.054 - Validation loss: 0.064
Epoch 30/60 Batch 0/77 - Training Loss: 0.053 - Validation loss: 0.061
Epoch 30/60 Batch 50/77 - Training Loss: 0.049 - Validation loss: 0.059
Epoch 31/60 Batch 0/77 - Training Loss: 0.045 - Validation loss: 0.058
Epoch 31/60 Batch 50/77 - Training Loss: 0.047 - Validation loss: 0.054
Epoch 32/60 Batch 0/77 - Training Loss: 0.040 - Validation loss: 0.055
Epoch 32/60 Batch 50/77 - Training Loss: 0.043 - Validation loss: 0.050
Epoch 33/60 Batch 0/77 - Training Loss: 0.034 - Validation loss: 0.046
Epoch 33/60 Batch 50/77 - Training Loss: 0.036 - Validation loss: 0.047
Epoch 34/60 Batch 0/77 - Training Loss: 0.032 - Validation loss: 0.042
Epoch 34/60 Batch 50/77 - Training Loss: 0.032 - Validation loss: 0.045
Epoch 35/60 Batch 0/77 - Training Loss: 0.031 - Validation loss: 0.041
Epoch 35/60 Batch 50/77 - Training Loss: 0.029 - Validation loss: 0.043
Epoch 36/60 Batch 0/77 - Training Loss: 0.027 - Validation loss: 0.039
Epoch 36/60 Batch 50/77 - Training Loss: 0.027 - Validation loss: 0.040
Epoch 37/60 Batch 0/77 - Training Loss: 0.024 - Validation loss: 0.037
Epoch 37/60 Batch 50/77 - Training Loss: 0.025 - Validation loss: 0.037
Epoch 38/60 Batch 0/77 - Training Loss: 0.022 - Validation loss: 0.033
Epoch 38/60 Batch 50/77 - Training Loss: 0.024 - Validation loss: 0.035
Epoch 39/60 Batch 0/77 - Training Loss: 0.020 - Validation loss: 0.029
Epoch 39/60 Batch 50/77 - Training Loss: 0.022 - Validation loss: 0.033
Epoch 40/60 Batch 0/77 - Training Loss: 0.018 - Validation loss: 0.027
Epoch 40/60 Batch 50/77 - Training Loss: 0.020 - Validation loss: 0.031
Epoch 41/60 Batch 0/77 - Training Loss: 0.017 - Validation loss: 0.026
Epoch 41/60 Batch 50/77 - Training Loss: 0.018 - Validation loss: 0.029
Epoch 42/60 Batch 0/77 - Training Loss: 0.016 - Validation loss: 0.025
Epoch 42/60 Batch 50/77 - Training Loss: 0.017 - Validation loss: 0.028
Epoch 43/60 Batch 0/77 - Training Loss: 0.015 - Validation loss: 0.024
Epoch 43/60 Batch 50/77 - Training Loss: 0.016 - Validation loss: 0.027
Epoch 44/60 Batch 0/77 - Training Loss: 0.014 - Validation loss: 0.023
Epoch 44/60 Batch 50/77 - Training Loss: 0.015 - Validation loss: 0.027
Epoch 45/60 Batch 0/77 - Training Loss: 0.013 - Validation loss: 0.022
Epoch 45/60 Batch 50/77 - Training Loss: 0.014 - Validation loss: 0.026
Epoch 46/60 Batch 0/77 - Training Loss: 0.012 - Validation loss: 0.021
Epoch 46/60 Batch 50/77 - Training Loss: 0.013 - Validation loss: 0.025
Epoch 47/60 Batch 0/77 - Training Loss: 0.011 - Validation loss: 0.020
Epoch 47/60 Batch 50/77 - Training Loss: 0.012 - Validation loss: 0.024
Epoch 48/60 Batch 0/77 - Training Loss: 0.011 - Validation loss: 0.019
Epoch 48/60 Batch 50/77 - Training Loss: 0.012 - Validation loss: 0.023
Epoch 49/60 Batch 0/77 - Training Loss: 0.010 - Validation loss: 0.019
Epoch 49/60 Batch 50/77 - Training Loss: 0.011 - Validation loss: 0.022
Epoch 50/60 Batch 0/77 - Training Loss: 0.010 - Validation loss: 0.018
Epoch 50/60 Batch 50/77 - Training Loss: 0.011 - Validation loss: 0.021
Epoch 51/60 Batch 0/77 - Training Loss: 0.010 - Validation loss: 0.018
Epoch 51/60 Batch 50/77 - Training Loss: 0.011 - Validation loss: 0.020
Epoch 52/60 Batch 0/77 - Training Loss: 0.010 - Validation loss: 0.018
Epoch 52/60 Batch 50/77 - Training Loss: 0.010 - Validation loss: 0.020
Epoch 53/60 Batch 0/77 - Training Loss: 0.010 - Validation loss: 0.018
Epoch 53/60 Batch 50/77 - Training Loss: 0.010 - Validation loss: 0.020
Epoch 54/60 Batch 0/77 - Training Loss: 0.009 - Validation loss: 0.019
Epoch 54/60 Batch 50/77 - Training Loss: 0.010 - Validation loss: 0.020
Epoch 55/60 Batch 0/77 - Training Loss: 0.008 - Validation loss: 0.020
Epoch 55/60 Batch 50/77 - Training Loss: 0.009 - Validation loss: 0.021
Epoch 56/60 Batch 0/77 - Training Loss: 0.007 - Validation loss: 0.019
Epoch 56/60 Batch 50/77 - Training Loss: 0.007 - Validation loss: 0.017
Epoch 57/60 Batch 0/77 - Training Loss: 0.006 - Validation loss: 0.016
Epoch 57/60 Batch 50/77 - Training Loss: 0.006 - Validation loss: 0.016
Epoch 58/60 Batch 0/77 - Training Loss: 0.006 - Validation loss: 0.016
Epoch 58/60 Batch 50/77 - Training Loss: 0.006 - Validation loss: 0.015
Epoch 59/60 Batch 0/77 - Training Loss: 0.006 - Validation loss: 0.016
Epoch 59/60 Batch 50/77 - Training Loss: 0.006 - Validation loss: 0.015
Epoch 60/60 Batch 0/77 - Training Loss: 0.005 - Validation loss: 0.015
Epoch 60/60 Batch 50/77 - Training Loss: 0.005 - Validation loss: 0.014
Model Trained and Saved
Prediction
def source_to_seq(text):
    '''
    Convert the source text into a sequence of ids
    '''
    sequence_length = 7  # pad inputs shorter than 7 with <PAD>; longer inputs are simply left unpadded
    return [source_letter_to_int.get(word, source_letter_to_int['<UNK>']) for word in text] + [source_letter_to_int['<PAD>']] * (sequence_length - len(text))

# input word
input_word = 'comoacxz'
text = source_to_seq(input_word)

checkpoint = "./trained_model.ckpt"

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # load the model
    loader = tf.train.import_meta_graph(checkpoint + '.meta')
    loader.restore(sess, checkpoint)

    input_data = loaded_graph.get_tensor_by_name('inputs:0')
    logits = loaded_graph.get_tensor_by_name('predictions:0')
    source_sequence_length = loaded_graph.get_tensor_by_name('source_sequence_length:0')
    target_sequence_length = loaded_graph.get_tensor_by_name('target_sequence_length:0')

    answer_logits = sess.run(logits, {input_data: [text] * batch_size,
                                      target_sequence_length: [len(input_word)] * batch_size,
                                      source_sequence_length: [len(input_word)] * batch_size})[0]

pad = source_letter_to_int["<PAD>"]

print('Original input:', input_word)

print('\nSource')
print('  Word ids:    {}'.format([i for i in text]))
print('  Input Words: {}'.format(" ".join([source_int_to_letter[i] for i in text])))

print('\nTarget')
print('  Word ids:       {}'.format([i for i in answer_logits if i != pad]))
print('  Response Words: {}'.format(" ".join([target_int_to_letter[i] for i in answer_logits if i != pad])))
Output of the run:
Original input: comoacxz
Source
  Word ids:    [20, 29, 8, 29, 13, 20, 11, 21]
  Input Words: c o m o a c x z
Target
  Word ids:       [13, 20, 20, 8, 29, 29, 11, 21]
  Response Words: a c c m o o x z