Writing Poems with a Python Program: Automatic Poem Generation with TensorFlow

# -*- coding: utf-8 -*-
# file: tang_poems.py

import collections
import os
import sys
import heapq

import numpy as np
import tensorflow as tf

from models.model import rnn_model
from dataset.poems import process_poems, generate_batch
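Note that everything below uses the TensorFlow 1.x API (tf.app.flags, tf.placeholder, tf.Session). On a TensorFlow 2.x install, the usual workaround is the v1 compatibility layer; a minimal shim, offered as an assumption since the original post targets TF 1.x directly:

# TF 2.x only: route the 1.x API through the compatibility layer.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()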

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
# set this to 'main.py' relative path
tf.app.flags.DEFINE_string('checkpoints_dir', os.path.abspath('./checkpoints/poems/'), 'checkpoints save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./dataset/data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS

start_token = 'G'
end_token = 'E'
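start_token and end_token frame every poem: generation below is primed with 'G' and stops once 'E' is sampled. The helpers process_poems and generate_batch come from dataset/poems.py, which the post never shows. For context, here is a minimal sketch of what the preprocessing plausibly does; the name process_poems_sketch, the "title:content" line format, and the filtering rules are all assumptions:

def process_poems_sketch(file_name):
    # Sketch only: one "title:content" poem per line, content wrapped as G...E.
    poems = []
    with open(file_name, encoding='utf-8') as f:
        for line in f:
            try:
                _, content = line.strip().split(':', 1)
            except ValueError:
                continue
            poems.append(start_token + content.replace(' ', '') + end_token)
    # Vocabulary ordered by character frequency, most frequent first.
    counter = collections.Counter(ch for poem in poems for ch in poem)
    words, _ = zip(*counter.most_common())
    word_to_int = {w: i for i, w in enumerate(words)}
    poems_vector = [[word_to_int[ch] for ch in poem] for poem in poems]
    return poems_vector, word_to_int, words

generate_batch then presumably cuts poems_vector into batch_size-sized groups and pads every poem in a group to that group's longest length, which is why the placeholders below have shape [batch_size, None].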

def run_training():
    # checkpoint save path configuration
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)

    # 1. preprocess the poem corpus
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # 2. generate batches for training
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # 3. build the model
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    # 4. start training
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('[INFO] start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):

                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))
                # save a checkpoint every 6 epochs
                if epoch % 6 == 0:
                    saver.save(sess, './model/', global_step=epoch)
                    # saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('[INFO] Interrupt manually, try saving checkpoint for now...')

            saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
            print('[INFO] Last epoch was saved, next time will start from epoch {}.'.format(epoch))


def to_word(predict, vocabs):
    # weighted sampling: build the CDF of the predicted distribution
    # and locate where a uniform random draw falls
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    if sample >= len(vocabs):  # clamp: avoid indexing past the vocabulary
        sample = len(vocabs) - 1
    return vocabs[sample]
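For readers who find the CDF-plus-searchsorted idiom opaque: the same weighted draw can be written with np.random.choice. A sketch equivalent in spirit (to_word_alt is illustrative, not from the post):

def to_word_alt(predict, vocabs):
    # Renormalize over the known vocabulary, then sample one index.
    p = np.asarray(predict).ravel()[:len(vocabs)]
    p = p / p.sum()
    return vocabs[np.random.choice(len(vocabs), p=p)]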

def gen_poem(begin_word):
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init_op)

        # checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        checkpoint = tf.train.latest_checkpoint('./model/')
        # saver.restore(sess, checkpoint)
        saver.restore(sess, './model/-24')  # hard-coded checkpoint from epoch 24

        # prime the network with the start token
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        if begin_word:
            word = begin_word
        else:
            word = to_word(predict, vocabularies)

        poem = ''
        while word != end_token:
            print('running')
            poem += word
            x = np.zeros((1, 1))
            x[0, 0] = word_int_map[word]
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x,
                                                        end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)
            # word = words[np.argmax(probs_)]
        return poem
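A note on the decoding loop above: the first sess.run primes the LSTM with start_token alone, and every later step feeds back a single character id together with the previous last_state through end_points['initial_state']. That state is what carries the context of everything generated so far; the loop stops once the sampled character is end_token.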

def pretty_print_poem(poem):
    # print only sentences longer than 10 characters, split on '。'
    poem_sentences = poem.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')


def main(is_train):
    if is_train:
        print('[INFO] train tang poem...')
        run_training()
    else:
        print('[INFO] write tang poem...')
        begin_word = input('Enter the first character: ')
        # begin_word = '我'
        poem2 = gen_poem(begin_word)
        pretty_print_poem(poem2)


if __name__ == '__main__':
    tf.app.run()
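One caveat about the entry point: tf.app.run() invokes main(argv) with the remaining command-line arguments, so is_train is bound to a non-empty list and is always truthy; as written, the script therefore always trains. A minimal flag-driven alternative (the train flag and the _entry wrapper are hypothetical additions, not from the original post):

# Hypothetical: select train vs. generate with an explicit boolean flag.
tf.app.flags.DEFINE_boolean('train', False, 'train the model instead of generating a poem.')

def _entry(_argv):
    main(FLAGS.train)

if __name__ == '__main__':
    tf.app.run(main=_entry)

With this wrapper, python tang_poems.py --train trains the model, and running the script with no flag generates a poem.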
