Making Acrostic Poems with an LSTM Recurrent Neural Network (RNN)


Introduction

I came across acrostic poems (藏头诗) by chance and found them very interesting. However, existing generators seem to work mostly by recombining lines from classical poems that already exist. For example, with the input 清风袭来, the result looks like the screenshot below.
(Screenshot: acrostic poem produced by an existing online tool for the input 清风袭来)

Later I thought it would be fun to build an acrostic poem generator with deep learning, and found that someone on GitHub had already published code for generating Tang poems.
Full code: https://github.com/jinfagang/tensorflow_poems
Starting from that project, I made some small modifications and added comments, hoping this will help anyone who needs something similar.

Model: model.py

import tensorflow as tf
import numpy as np


def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
              learning_rate=0.01):
    """
    construct rnn seq2seq model.
    :param model: model class
    :param input_data: input data placeholder
    :param output_data: output data placeholder
    :param vocab_size:
    :param rnn_size:
    :param num_layers:
    :param batch_size:
    :param learning_rate:
    :return:
    """
    end_points = {}
    # choose which RNN cell type to use
    if model == 'rnn':
        cell_fun = tf.contrib.rnn.BasicRNNCell
    elif model == 'gru':
        cell_fun = tf.contrib.rnn.GRUCell
    elif model == 'lstm':
        cell_fun = tf.contrib.rnn.BasicLSTMCell

    # build an independent cell for each layer; reusing one cell object across layers
    # raises a variable-scope error in newer TF 1.x releases
    cell = tf.contrib.rnn.MultiRNNCell([cell_fun(rnn_size, state_is_tuple=True) for _ in range(num_layers)],
                                       state_is_tuple=True)

    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)

    with tf.device("/cpu:0"):  # keep the embedding lookup on the CPU
        embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
                [vocab_size + 1, rnn_size], -1.0, 1.0))

        inputs = tf.nn.embedding_lookup(embedding, input_data)

    # [batch_size, ?, rnn_size] = [64, ?, 128]
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    # [?, vocab_size+1]

    if output_data is not None:
        # output_data must be one-hot encode
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # should be [?, vocab_size+1]

        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        # loss shape is [?] (one value per time step)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction

    return end_points

Text preprocessing: poems.py

import collections
import numpy as np

start_token = 'B'
end_token = 'E'


def process_poems(file_name):
    # poems -> list of numbers
    poems = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        for line in f.readlines():
            try:
                title, content = line.strip().split(':')  # each line is "title:content"; split on the colon
                content = content.replace(' ', '')  # strip spaces from the content
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                        start_token in content or end_token in content:  # skip poems containing garbled or special characters (length is filtered below)
                    continue
                if len(content) < 5 or len(content) > 79:
                    continue

                content = start_token + content + end_token  # wrap each poem as "B<content>E"
                poems.append(content)
            except ValueError as e:
                pass
    # poems = sorted(poems, key=len)

    all_words = [word for poem in poems for word in poem]
    counter = collections.Counter(all_words)
    words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)  # sort characters by frequency, descending
    words.append(' ')  # append a space character, later used for padding
    L = len(words)
    word_int_map = dict(zip(words, range(L)))  # map each character to an integer ID; more frequent characters get smaller IDs
    poems_vector = [list(map(lambda word: word_int_map.get(word, L), poem)) for poem in poems]  # convert every poem into a list of integer IDs

    return poems_vector, word_int_map, words


def generate_batch(batch_size, poems_vec, word_to_int):

    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size

        batches = poems_vec[start_index:end_index]
        length = max(map(len, batches))  # length of the longest poem in this batch
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)

        # copy every poem in this batch into x_data; shorter poems are padded with spaces
        for row, batch in enumerate(batches):
            x_data[row, :len(batch)] = batch
        #print(x_data.ndim)
        y_data = np.copy(x_data)
        y_data[:, :-1] = x_data[:, 1:]  # y_data is x_data shifted left by one position (next-character targets)
        """
        x_data             y_data
        [6,2,4,6,9]       [2,4,6,9,9]
        [1,4,2,8,5]       [4,2,8,5,5]
        """
        x_batches.append(x_data)  # collect this batch
        y_batches.append(y_data)
    return x_batches, y_batches
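
To sanity-check the preprocessing, you can load the corpus and print one batch. This is a minimal sketch of my own (it assumes the repository layout used here, i.e. poems/poems.py and data/qijue-all.txt):

from poems.poems import process_poems, generate_batch

poems_vector, word_int_map, vocabularies = process_poems('./data/qijue-all.txt')
print('number of poems:', len(poems_vector))
print('vocabulary size:', len(vocabularies))

x_batches, y_batches = generate_batch(64, poems_vector, word_int_map)
print('number of batches:', len(x_batches))
print('input batch shape:', x_batches[0].shape)   # (64, length of the longest poem in this batch)
print('target batch shape:', y_batches[0].shape)  # same shape, shifted left by one character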

Training: train.py

import os
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/qijue-all.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                    # every 100 batches, append the epoch, batch index and loss to a text file
                    if batch % 100 == 0:
                        f = open('/users/damon/desktop/tensorflow_poems-master/txt/data100.txt', 'a')
                        f.write(str(epoch) + ',' + str(batch) + ',' + str(loss) + '\n')
                        f.close()
                # save a checkpoint every 6 epochs
                if epoch % 6 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:  # on manual interruption, save a checkpoint so training can resume from it next time
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

Acrostic poem generation: compose_poems.py

import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np

start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/qijue-all.txt'

lr = 0.0002


def to_word(predict, vocabs):  # sample one character from the predicted distribution
    predict = predict[0]
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)  # draw one index, each with probability given by predict
    if sample >= len(vocabs):  # the last logit has no corresponding character; fall back to the padding token
        return vocabs[-1]
    else:
        #print(vocabs[sample])
        return vocabs[sample]


def gen_poem(begin_word):
    batch_size = 1
    print('## loading corpus from %s' % corpus_file)
    poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
            vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        #x = np.array([list(map(word_int_map.get, start_token))])

        #[predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         #feed_dict={input_data: x})

        #word = begin_word or to_word(predict, vocabularies)
        poem_ = ''

        for j in range(len(begin_word)):
            word = begin_word[j]
            while word != end_token:
                poem1 = ''
                x = np.array([list(map(word_int_map.get, start_token))])

                [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                 feed_dict={input_data: x})
                while len(poem1) < 18:

                    poem1 += word
                    # poem_ += word
                    #i += 1
                    if word == ',':  # make sure one line does not contain more than one comma
                        if len(poem1) > 9:
                            poem1 = begin_word[j]
                            word = begin_word[j]
                    if word == '。':  # make sure a line is not too short
                        if len(poem1) > 10:
                            break
                        else:
                            poem1 = begin_word[j]
                            word = begin_word[j]
                    x = np.array([[word_int_map[word]]])  # feed the current character back in as the next input
                    [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                     feed_dict={input_data: x, end_points['initial_state']: last_state})
                    word = to_word(predict, vocabularies)  # predict the next character
                poem_ += poem1
                break


        return poem_

# print the generated acrostic poem
def pretty_print_poem(poem_):
    poem_sentences = poem_.split('。')
    k = 0
    for s in poem_sentences:

        if s != '' and len(s) > 10:  # drop lines that are too short

            if k > len(begin_char) - 1:
                break
            if s[0] == begin_char[k]:  # the first character of each line must be the character the user entered
                print(s + '。')
                k += 1


if __name__ == '__main__':
    begin_char = input('## please input the characters you want to compose:')
    poem = gen_poem(begin_char)
    pretty_print_poem(poem_=poem)

Results

The test results are shown below. With the input 九月十二:
(Screenshot: generated acrostic poem for the input 九月十二)
With the input 正在下架:
(Screenshot: generated acrostic poem for the input 正在下架)

The generated acrostic poems actually look fairly decent, even though I cannot really make sense of them. In any case, our acrostic poem generator is done!

Problem analysis

There is one remaining issue: the training set is not large enough, so generation can fail with an error. For example, with the input 清风袭来:
(Screenshot: KeyError traceback for the input 清风袭来)
A KeyError is raised. My guess is that the character '风' never appears in the training corpus, so it has no entry in word_int_map. If you need more robust results, you can look for or build a larger dataset and retrain the model.
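
As a quick workaround (a minimal sketch of my own, not part of the original repository), you can replace the direct dictionary lookup in gen_poem with a .get() fallback, so an unseen character maps to the padding character instead of raising a KeyError:

# hypothetical defensive lookup inside the generation loop of gen_poem
unknown_id = word_int_map[' ']                         # ' ' is the padding character appended in process_poems
x = np.array([[word_int_map.get(word, unknown_id)]])   # fall back to the padding id for unseen characters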

Finally, let's plot the loss values recorded in the text file generated during training.

import matplotlib.pyplot as plt
import numpy as np

file_name1 = 'data100.txt'

x = []
y = []
with open(file_name1) as file_object:
    lines = file_object.readlines()
    for line in lines:
        line = line.split(',')
        a = int(line[0]) * 1200  # turn (epoch, batch) into a global step; 1200 is roughly the number of batches per epoch here
        b = int(line[1])
        x.append(a + b)
        y.append(float(line[2]))
print(np.min(y))               # lowest recorded loss
print(y.index(min(y)))         # position of the lowest loss in the log
print(x[y.index(min(y))])      # the corresponding global step
# plt.scatter(x, y, color='blue', s=1)
plt.plot(x, y)
plt.show()

(Figure: training loss versus global step)
The loss reaches its lowest value at roughly epoch 26; after that the model appears to overfit. If you are interested, you can try to address this by changing the number of layers, the number of units per layer, the learning rate, and so on, for example by adding dropout as sketched below.
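
One common way to reduce this overfitting (a minimal sketch of my own, not part of the original code; it assumes the same TF 1.x tf.contrib API used above and an example keep probability of 0.8) is to wrap each cell with dropout inside rnn_model:

# inside rnn_model, replacing the MultiRNNCell construction
def make_cell():
    cell = cell_fun(rnn_size, state_is_tuple=True)
    # output_keep_prob < 1.0 only during training; use 1.0 when generating poems
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.8)

cell = tf.contrib.rnn.MultiRNNCell([make_cell() for _ in range(num_layers)], state_is_tuple=True)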
