# -*- coding: utf-8 -*-
# file: tang_poems.py
import collections
import heapq
import os
import sys

import numpy as np
import tensorflow as tf

from models.model import rnn_model
from dataset.poems import process_poems, generate_batch
# Command-line flags configuring training and generation.
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
# Paths are resolved relative to the process working directory
# (set this to the 'main.py' relative path).
tf.app.flags.DEFINE_string('checkpoints_dir', os.path.abspath('./checkpoints/poems/'), 'checkpoints save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./dataset/data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS

# Sentinel tokens marking the start ('G') and end ('E') of every poem sequence.
start_token = 'G'
end_token = 'E'
def run_training():
    """Train the LSTM poem model, restoring from and saving to FLAGS.checkpoints_dir.

    Side effects: creates the checkpoint directory, runs a TF session, writes
    checkpoints every 6 epochs and on manual interrupt (Ctrl-C).
    """
    # Model checkpoint path setup: ensure parent dir and dir both exist.
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)

    # 1. Load and vectorize the poem corpus.
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # 2. Build mini-batches for training.
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # 3. Build the model graph. batch_size now follows FLAGS.batch_size
    # (was hard-coded to 64, which silently broke if the flag was changed).
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    # 4. Train.
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            # Checkpoint filenames end in '-<global_step>'; resume from there.
            start_epoch += int(checkpoint.split('-')[-1])
        print('[INFO] start training...')
        epoch = start_epoch  # keep defined if interrupted before the loop starts
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op'],
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    # Save under the configured checkpoint dir (was './model/',
                    # which the restore path above would never find).
                    saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix),
                               global_step=epoch)
        except KeyboardInterrupt:
            print('[INFO] Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
            print('[INFO] Last epoch were saved, next time will start from epoch {}.'.format(epoch))


def to_word(predict, vocabs):
    """Sample one word from an (unnormalized) probability vector *predict*.

    Uses inverse-CDF sampling over the cumulative sum of *predict*.
    """
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    # np.searchsorted may return len(t); the original '>' test let
    # sample == len(vocabs) through, causing an IndexError. Clamp with '>='.
    if sample >= len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]
def gen_poem(begin_word):
    """Generate one poem, optionally forcing its first character to *begin_word*.

    Restores the latest trained checkpoint and samples characters one at a
    time until the end token is produced. Returns the poem as a string.
    """
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)

    input_data = tf.placeholder(tf.int32, [batch_size, None])
    # output_data=None selects inference mode. batch_size follows the local
    # batch_size (was hard-coded 64, mismatching the [1, None] placeholder).
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        # Restore the most recent checkpoint from the configured directory
        # (was a hard-coded './model/-24', ignoring latest_checkpoint).
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        saver.restore(sess, checkpoint)

        # Seed the network with the start token to get the first prediction.
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        if begin_word:
            word = begin_word
        else:
            word = to_word(predict, vocabularies)

        poem = ''
        # NOTE(review): loops until the model emits end_token; a poorly trained
        # model could loop for a long time -- consider a max-length cap.
        while word != end_token:
            poem += word
            # int dtype to match the int32 placeholder (was float zeros).
            x = np.zeros((1, 1), dtype=np.int32)
            x[0, 0] = word_int_map[word]
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x,
                                                        end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)
        return poem
def pretty_print_poem(poem):
    """Print *poem* one sentence per line, dropping fragments of 10 chars or fewer."""
    poem_sentences = poem.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')


def main(is_train):
    # NOTE(review): tf.app.run() calls main(sys.argv), so `is_train` receives
    # the argv list here, which is always truthy -- the generation branch is
    # unreachable as written. Confirm the intended switch (e.g. a
    # DEFINE_boolean flag) with the project owner.
    if is_train:
        print('[INFO] train tang poem...')
        run_training()
    else:
        print('[INFO] write tang poem...')
        begin_word = input('输入起始字:')
        # begin_word = '我'
        poem2 = gen_poem(begin_word)
        pretty_print_poem(poem2)


if __name__ == '__main__':
    tf.app.run()