基于RNN循环神经网络lstm的藏头诗制作
简单介绍
在一次偶然的机会中接触到藏头诗,觉得十分有意思。但现有的藏头诗生成工具似乎都只是把古代已有的诗句重新组合而成。比如输入“清风袭来”,结果如下图所示。
于是想到不如利用深度学习来生成藏头诗,并发现 GitHub 上已有作者实现了唐诗生成的相关代码。
完整代码地址https://github.com/jinfagang/tensorflow_poems
在此基础上,我对代码稍作修改并添加了注释,希望能帮到有这方面需求的同学。
模型model.py
import tensorflow as tf
import numpy as np
def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
              learning_rate=0.01):
    """
    Construct a character-level RNN language model (training or inference graph).

    :param model: recurrent cell type, one of 'rnn', 'gru' or 'lstm'.
    :param input_data: int32 placeholder of character ids, shape [batch, time].
    :param output_data: int32 placeholder of target character ids (same shape as
        ``input_data``); pass ``None`` to build the inference/sampling graph.
    :param vocab_size: number of distinct characters; the embedding table and the
        softmax layer both use ``vocab_size + 1`` rows/classes.
    :param rnn_size: hidden units per layer (also the embedding width).
    :param num_layers: number of stacked recurrent layers.
    :param batch_size: batch size of the zero initial state when training; the
        inference graph always uses a batch of 1.
    :param learning_rate: Adam learning rate (training graph only).
    :return: dict of graph endpoints. Training graph: 'initial_state', 'output',
        'train_op', 'total_loss', 'loss', 'last_state'. Inference graph:
        'initial_state', 'last_state', 'prediction'.
    :raises ValueError: if ``model`` is not one of the supported cell types.
    """
    end_points = {}

    # Choose the recurrent cell type.
    cell_classes = {
        'rnn': tf.contrib.rnn.BasicRNNCell,
        'gru': tf.contrib.rnn.GRUCell,
        'lstm': tf.contrib.rnn.BasicLSTMCell,
    }
    if model not in cell_classes:
        raise ValueError("model must be one of 'rnn', 'gru', 'lstm', got %r" % (model,))
    cell_fun = cell_classes[model]

    def make_cell():
        # Only BasicLSTMCell accepts ``state_is_tuple``; passing it to
        # BasicRNNCell / GRUCell raises TypeError.
        if model == 'lstm':
            return cell_fun(rnn_size, state_is_tuple=True)
        return cell_fun(rnn_size)

    # One fresh cell object per layer. ``[cell] * num_layers`` would put the very
    # same object (and therefore the same weights) in every layer, which newer
    # TF 1.x releases reject.
    cell = tf.contrib.rnn.MultiRNNCell([make_cell() for _ in range(num_layers)], state_is_tuple=True)

    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        # Sampling feeds one sequence at a time.
        initial_state = cell.zero_state(1, tf.float32)

    with tf.device("/cpu:0"):  # keep the embedding lookup on the CPU
        embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
            [vocab_size + 1, rnn_size], -1.0, 1.0))
        inputs = tf.nn.embedding_lookup(embedding, input_data)

    # outputs: [batch_size, ?, rnn_size]
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    # logits: [?, vocab_size + 1]

    if output_data is not None:
        # output_data holds integer ids; one-hot encode them here: [?, vocab_size + 1]
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # Per-position cross-entropy, shape [?].
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction
    return end_points
文本处理 poems.py
import collections
import numpy as np
# Marker characters wrapped around every poem: 'B' opens it, 'E' closes it.
start_token, end_token = 'B', 'E'
def process_poems(file_name):
    """
    Load a poem corpus and encode every poem as a list of character ids.

    Each line of *file_name* is expected to look like ``title:content``. Lines
    that do not split into exactly two fields are skipped. Also skipped are
    poems containing noise/annotation characters ('_', '(', '（', '《', '[') or
    the start/end tokens, and poems shorter than 5 or longer than 79 characters.
    Every kept poem is wrapped as ``start_token + content + end_token``.

    :param file_name: path to a UTF-8 corpus file.
    :return: (poems_vector, word_int_map, words). ``words`` is the vocabulary
        sorted by descending frequency with ' ' appended as the padding symbol;
        ``word_int_map`` maps each character to its index in ``words``;
        ``poems_vector`` holds every kept poem encoded with that map.
    """
    # Both the ASCII '(' and the full-width '（' are rejected; the previous code
    # tested '(' twice and so never filtered full-width parentheses.
    banned = ('_', '(', '（', '《', '[', start_token, end_token)

    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f:
            try:
                title, content = line.strip().split(':')  # "title:content"
            except ValueError:
                # Malformed line (no ':' or more than one): skip it.
                continue
            content = content.replace(' ', '')  # drop spaces inside the poem
            # Skip garbled/annotated poems and poems that are too short or too long.
            if any(ch in content for ch in banned):
                continue
            if len(content) < 5 or len(content) > 79:
                continue
            poems.append(start_token + content + end_token)  # "B content E"

    all_words = [word for poem in poems for word in poem]
    counter = collections.Counter(all_words)
    # Most frequent characters first (descending), so frequent characters get small ids.
    words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)
    words.append(' ')  # padding symbol used when batching
    L = len(words)
    word_int_map = dict(zip(words, range(L)))
    # Encode every poem character by character.
    poems_vector = [list(map(lambda word: word_int_map.get(word, L), poem)) for poem in poems]
    return poems_vector, word_int_map, words
def generate_batch(batch_size, poems_vec, word_to_int):
    """
    Cut encoded poems into fixed-size input/target batches for training.

    Only whole batches are produced; trailing poems that do not fill a batch are
    dropped. Inside a batch every row is right-padded with the id of ' ' up to
    the longest poem of that batch. The target matrix is the input shifted one
    step to the left, with the last column repeated, e.g.

        x: [6, 2, 4, 6, 9]  ->  y: [2, 4, 6, 9, 9]

    :param batch_size: number of poems per batch.
    :param poems_vec: list of encoded poems (lists of ints).
    :param word_to_int: character-to-id map; must contain ' ' when at least one
        batch is produced.
    :return: (x_batches, y_batches), two lists of int32 arrays of shape
        [batch_size, longest poem in that batch].
    """
    inputs, targets = [], []
    usable = (len(poems_vec) // batch_size) * batch_size
    for begin in range(0, usable, batch_size):
        group = poems_vec[begin:begin + batch_size]
        width = max(len(poem) for poem in group)
        x = np.full((batch_size, width), word_to_int[' '], np.int32)
        for row, poem in enumerate(group):
            x[row, :len(poem)] = poem
        # Next-character targets: shift left by one, keep the final column.
        y = x.copy()
        y[:, :-1] = x[:, 1:]
        inputs.append(x)
        targets.append(y)
    return inputs, targets
训练 train.py
import os
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch
# Command-line flags for training (override on the command line, e.g. --epochs=20).
# Poems per batch; run_training also uses it as the fixed first dimension of the placeholders.
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
# Learning rate handed to rnn_model (used by its Adam optimizer).
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
# Directory where checkpoints are saved and restored from; created if missing.
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
# Training corpus; process_poems expects one "title:content" poem per line.
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/qijue-all.txt'), 'file name of poems.')
# File-name prefix of the saved checkpoints inside model_dir.
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
# Total number of epochs; a resumed run continues from the restored epoch up to this value.
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')
FLAGS = tf.app.flags.FLAGS
def run_training():
    """
    Train the poem language model, resuming from the latest checkpoint if any.

    Reads the corpus at FLAGS.file_path and trains an LSTM for up to
    FLAGS.epochs epochs on batches of FLAGS.batch_size poems. Every 100 batches
    the line ``epoch,batch,loss`` is appended to ``data100.txt`` inside
    FLAGS.model_dir (for plotting). A checkpoint is saved every 6 epochs, after
    the final epoch, and when training is interrupted with Ctrl-C.
    """
    os.makedirs(FLAGS.model_dir, exist_ok=True)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    # batch_size must agree with the placeholders: rnn_model builds the zero
    # initial state with it (it was hard-coded to 64 before).
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    checkpoint_prefix = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)
    # Loss log next to the checkpoints instead of a machine-specific absolute path.
    loss_log_path = os.path.join(FLAGS.model_dir, 'data100.txt')

    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            # Checkpoints are saved with global_step=epoch, e.g. ".../poems-12".
            start_epoch += int(checkpoint.split('-')[-1])

        print('## start training...')
        # Bound before the loop so the interrupt handler can always reference it.
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch, (x_batch, y_batch) in enumerate(zip(batches_inputs, batches_outputs)):
                    loss, _ = sess.run(
                        [end_points['total_loss'], end_points['train_op']],
                        feed_dict={input_data: x_batch, output_targets: y_batch})
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                    # Record epoch, batch and loss every 100 batches.
                    if batch % 100 == 0:
                        with open(loss_log_path, 'a') as log:
                            log.write(str(epoch) + ',' + str(batch) + ',' + str(loss) + '\n')
                # Save every 6 epochs, and always after the last epoch so the
                # final weights are not lost.
                if epoch % 6 == 0 or epoch == FLAGS.epochs - 1:
                    saver.save(sess, checkpoint_prefix, global_step=epoch)
        except KeyboardInterrupt:
            # Manual stop: save a checkpoint so the next run resumes from here.
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, checkpoint_prefix, global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
def main(_):
    """Entry point handed to tf.app.run(); the argv argument is ignored."""
    return run_training()
if __name__ == '__main__':
    # Hand control to tf.app.run, which invokes main() above.
    tf.app.run(main=main)
藏头诗生成 compose_poems.py
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np
# Poem boundary markers; must match the ones used by poems.py ('B' ... 'E').
start_token, end_token = 'B', 'E'
# Checkpoint directory read by gen_poem, and the corpus used to rebuild the vocabulary.
model_dir, corpus_file = './model/', './data/qijue-all.txt'
# Only forwarded to rnn_model; the inference graph creates no optimizer.
lr = 0.0002
def to_word(predict, vocabs):
    """
    Sample one character from the model's output distribution.

    :param predict: softmax output of the inference graph, shape
        [1, len(vocabs) + 1]; only row 0 is used. It is not modified.
    :param vocabs: vocabulary list; index i is the character with id i.
    :return: the sampled character. The extra output class (id == len(vocabs)),
        which has no vocabulary entry, is mapped to ``vocabs[-1]``.
    """
    # Work on a float64 copy: predict[0] is a view, so in-place normalisation
    # would mutate the caller's array, and float32 round-off can push the sum
    # outside np.random.choice's tolerance.
    probs = np.asarray(predict[0], dtype=np.float64)
    probs = probs / probs.sum()
    # Draw one id with probability probs[id].
    sample = np.random.choice(len(probs), p=probs)
    # Valid vocabulary indices are 0 .. len(vocabs) - 1; the old ``>`` test let
    # sample == len(vocabs) through and raised IndexError.
    if sample >= len(vocabs):
        return vocabs[-1]
    return vocabs[sample]
def gen_poem(begin_word):
    """
    Generate an acrostic poem, one line per character of *begin_word*.

    For every head character the network is first primed with ``start_token``.
    Then the head character and each subsequently sampled character are fed
    back one at a time, carrying the LSTM state, until either a line ending in
    '。' and longer than 10 characters is produced, or 18 characters have been
    accumulated. Two cases restart the line from the head character:
    a comma that would make the line longer than 9 characters, and a '。' that
    arrives too early.

    :param begin_word: head characters; a character equal to ``end_token`` is skipped.
    :return: all generated lines concatenated (format them with pretty_print_poem).
    :raises KeyError: if a head character never occurs in the training corpus.
    :raises ValueError: if no checkpoint exists in ``model_dir``.
    """
    batch_size = 1
    print('## loading corpus from %s' % corpus_file)
    poems_vector, word_int_map, vocabularies = process_poems(corpus_file)

    # Fail fast with a readable message (e.g. when '风' never occurs in the
    # corpus) instead of a bare KeyError in the middle of sampling.
    unknown = [ch for ch in begin_word if ch != end_token and ch not in word_int_map]
    if unknown:
        raise KeyError('characters not found in the training corpus: %s' % ''.join(unknown))

    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=batch_size, learning_rate=lr)
    fetches = [end_points['prediction'], end_points['last_state']]
    # The corpus may use the ASCII or the full-width comma; treat both alike.
    commas = (',', '，')

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint is None:
            raise ValueError('no checkpoint found in %s; train the model first' % model_dir)
        saver.restore(sess, checkpoint)

        start_x = np.array([[word_int_map[start_token]]])
        poem_ = ''
        for head in begin_word:
            if head == end_token:
                continue
            # Prime the network with the start token; only the state is kept.
            _, last_state = sess.run(fetches, feed_dict={input_data: start_x})
            line = ''
            word = head
            while len(line) < 18:
                line += word
                # Comma this late (line already > 9 chars): restart from the head
                # character, so a line does not collect several commas.
                if word in commas and len(line) > 9:
                    line = head
                    word = head
                if word == '。':
                    if len(line) > 10:
                        break
                    # Line ended too early: restart from the head character.
                    line = head
                    word = head
                x = np.array([[word_int_map[word]]])
                predict, last_state = sess.run(
                    fetches, feed_dict={input_data: x, end_points['initial_state']: last_state})
                word = to_word(predict, vocabularies)  # sample the next character
            poem_ += line
    return poem_
# Print the generated acrostic poem.
def pretty_print_poem(poem_, heads=None):
    """
    Print the generated lines whose first characters spell out the head characters.

    *poem_* is split on '。'. Fragments of 10 characters or fewer are ignored.
    A remaining fragment is printed (with '。' restored) only if it starts with
    the next expected head character. Printing stops once every head has been used.

    :param poem_: raw output of gen_poem.
    :param heads: the head characters, in order. When omitted, the module-level
        ``begin_char`` (set by the ``__main__`` block) is used, as before.
    """
    if heads is None:
        heads = begin_char
    k = 0
    for s in poem_.split('。'):
        # Skip fragments too short to be a full line (this also drops '').
        if len(s) <= 10:
            continue
        if k >= len(heads):
            break
        # The line must start with the character the user asked for.
        if s[0] == heads[k]:
            print(s + '。')
            k += 1
if __name__ == '__main__':
    # begin_char stays a module-level name because pretty_print_poem reads it.
    begin_char = input('## please input the characters you want to compose:')
    pretty_print_poem(poem_=gen_poem(begin_char))
结果
测试结果如下。输入“九月十二”:
输入“正在下架”:
这藏头诗写得是不是还有点水平?虽然我自己也不太看得懂。不过总而言之,我们的藏头诗制作完成啦!
问题分析
还有一个问题:训练集不够大,导致生成藏头诗时容易出错。例如输入“清风袭来”,
会出现 KeyError。我觉得原因是训练集中“风”字一次都没有出现过,希望有需求的同学可以自行寻找或制作更多的数据集对模型进行训练。
下面对训练过程中生成的 txt 日志文件(每 100 个 batch 记录一次 epoch、batch 和 loss)进行绘图。
import matplotlib.pyplot as plt
import numpy as np
file_name1 = 'data100.txt'
# Training batches per epoch for the run that produced the log
# (len(poems_vector) // batch_size in train.py). It turns (epoch, batch) pairs
# into one global step for the x axis; adjust it if the corpus or batch size changes.
BATCHES_PER_EPOCH = 1200


def load_loss_log(path, batches_per_epoch=BATCHES_PER_EPOCH):
    """
    Read the ``epoch,batch,loss`` log that train.py appends to every 100 batches.

    Blank lines are ignored; fields beyond the third are ignored.

    :param path: path of the log file.
    :param batches_per_epoch: batches in one epoch, used to build the global step.
    :return: (steps, losses), where step = epoch * batches_per_epoch + batch.
    """
    steps, losses = [], []
    with open(path) as log:
        for raw in log:
            raw = raw.strip()
            if not raw:
                continue
            fields = raw.split(',')
            steps.append(int(fields[0]) * batches_per_epoch + int(fields[1]))
            losses.append(float(fields[2]))
    return steps, losses


if __name__ == '__main__':
    x, y = load_loss_log(file_name1)
    best = y.index(min(y))  # first position of the lowest loss
    print(np.min(y))
    print(best)
    print(x[best])
    # plt.scatter(x, y, color='blue', s=1)
    plt.plot(x, y)
    plt.show()
大概在 epoch=26 时 loss 最低,之后不再继续下降。
我怀疑出现了过拟合现象。不过这里记录的是训练集上的 loss,严格判断是否过拟合还需要观察验证集上的 loss;训练 loss 停止下降甚至回升,也可能与学习率偏大有关。对此感兴趣的同学可以通过修改网络层数、神经元个数以及学习率等来改进。