Lesson 12: Generating Classical Chinese Poetry with an RNN in TensorFlow

The previous lesson covered how an RNN works. This lesson walks through a demo that generates classical Chinese poetry.

Input

# coding:utf-8
"""
Data input
"""

import logging
import collections
import json
import numpy as np


class PoemInput(object):

    def __init__(self, poem_file_path, batch_size):
        self._poem_file_path = poem_file_path
        self._batch_size = batch_size
        self._poems = list()
        self._poem_vectors = list()
        self._batch_num = 0
        self._chunk_size = None

        # dict mapping word -> index
        self._word_index_dict = None

        # dict mapping index -> word
        self._index_word_dict = None

    def process(self):
        self._convert_vector()

    def _read_file(self):
        line_no = 0
        with open(self._poem_file_path) as poem_file:
            for line in poem_file:

                line_no += 1

                line = line.decode('utf-8').strip()

                line_infos = line.strip().split(':')

                if len(line_infos) != 2:
                    continue

                title = line_infos[0]
                content = line_infos[1]

                content = content.replace(' ', '')
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
                    continue
                if len(content) < 5 or len(content) > 79:
                    continue
                content = '[' + content + ']'

                # logging.debug(str(line_no) + ': ' + title + ': ' + content)

                self._poems.append(content)

    def _convert_vector(self):
        self._read_file()

        # sort poems by length (number of characters)
        poetrys = sorted(self._poems, key=lambda line: len(line))
        logging.info(u'total number of poems: ' + str(len(poetrys)))

        # count the occurrences of each character
        all_words = []
        for poetry in poetrys:
            all_words += [word for word in poetry]

        counter = collections.Counter(all_words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])

        logging.debug('type count_pairs: ' + str(type(count_pairs)))

        for i in range(5):
            logging.debug('count_pairs ' + str(i) + ': ' + str(count_pairs[i]).decode('utf-8'))

        words, counts = zip(*count_pairs)

        logging.debug('type words: ' + str(type(words)))
        logging.debug('counts: ' + str(counts))
        logging.debug('words: ' + ' '.join(words))

        # keep all characters here (one could truncate to the top-N most frequent) and append a space as the padding character
        words = words[:len(words)] + (' ',)

        # map each character to a numeric ID
        self._word_index_dict = dict(zip(words, range(len(words))))

        # map each ID back to its character
        self._index_word_dict = dict(zip(range(len(words)), words))

        logging.debug('index_word_dict: ' + json.dumps(self._index_word_dict, ensure_ascii=False))
        logging.debug('word_index_map: ' + json.dumps(self._word_index_dict, ensure_ascii=False))

        # convert a character to its ID; unknown characters fall back to len(words)
        to_num = lambda word: self._word_index_dict.get(word, len(words))

        # walk through all poems and convert each one into a list of IDs
        self._poem_vectors = [list(map(to_num, poetry)) for poetry in poetrys]

        logging.debug(u'poem_vectors: ' + str(self._poem_vectors))
        logging.debug(u'poetry: ' + json.dumps(poetrys, ensure_ascii=False))

        self._chunk_size = len(self._poem_vectors) // self._batch_size

    def next_batch(self):

        batch_examples = list()

        for i in range(self._batch_size):
            poem = self._poem_vectors[i + self._batch_num * self._batch_size]

            batch_examples.append(poem)

        self._batch_num += 1

        if self._batch_num == self._chunk_size:  # wrap around and start over
            self._batch_num = 0

        # convert batch_examples into a numpy array

        # pad every poem to the length of the longest one in the batch
        poem_max_len = max(map(len, batch_examples))
        x_data = np.full([self._batch_size, poem_max_len], self._word_index_dict[' '], dtype=np.int32)

        for row in range(self._batch_size):
            x_data[row, :len(batch_examples[row])] = batch_examples[row]

        # the label is x_data shifted left by one position
        # each batch is a 2-D matrix: batch_size rows, one poem per row
        # for input a, b, c, d the label is b, c, d, d -- note the last label is d again
        # since a sample looks like [, a, b, c, d, ], the label is a, b, c, d, ], ]
        # so the label ends with ], and predicting ] at the end is exactly what we want

        y_data = x_data.copy()

        y_data[:, :-1] = x_data[:, 1:]

        return x_data, y_data

    def convert_poem_vector_2_poem(self, poem_vector):
        """
        把诗歌的index vector转换成具体的诗
        :param poem_vector: index的向量
        :return: 诗
        """

        return [self._index_word_dict[word_index] for word_index in poem_vector]

    def convert_poem_2_poem_vector(self, poem):

        return [self._word_index_dict[word] for word in poem]

    @property
    def word_dict(self):
        return self._word_index_dict

    @property
    def index_dict(self):
        return self._index_word_dict
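
A minimal usage sketch of PoemInput (hypothetical, not part of the original scripts; the path and batch size are illustrative, and it assumes a poetry file with one `title:content` poem per line):

# illustrative usage of PoemInput
poem_input = PoemInput('./input/poetry.txt', 4)
poem_input.process()

x, y = poem_input.next_batch()
print(x.shape)  # (4, longest_poem_in_batch), space-padded int32 IDs
print(''.join(poem_input.convert_poem_vector_2_poem(x[0])))  # the first poem as text
# y is x shifted left by one position: y[:, :-1] == x[:, 1:]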

Building the RNN model

# coding:utf-8
"""
Model definition
"""

import tensorflow as tf
import logging
import common


class Inference(object):
    BASIC_RNN_CELL = 'basic_rnn_cell'
    LSTM_BASIC_CELL = 'lstm_basic_cell'
    GRU_CELL = 'gru_cell'

    def __init__(self, hidden_unit_size, num_layers, class_num, batch_size):
        """
        初始化rnn
        :param hidden_unit_size: 隐层单元数
        :param num_layers: 多少层rnn
        :param class_num: 最终的分类结果,也就是所有单词的总量,每一个单词作为一个分类
        """

        self._hidden_unit_size = hidden_unit_size
        self._num_layers = num_layers
        self._class_num = class_num
        self._batch_size = batch_size

    def inference(self, model_type, inputs, targets):

        # targets is 2-D [batch_size, time]; flatten it to 1-D for the loss
        if targets is not None:
            tf.logging.info('1 targets shape: %s' % (str(targets.shape)))
            targets = tf.reshape(targets, [-1])
            tf.logging.info('2 targets shape: %s' % (str(targets.shape)))

        def make_cell():
            if model_type == Inference.BASIC_RNN_CELL:
                return tf.nn.rnn_cell.BasicRNNCell(self._hidden_unit_size)
            elif model_type == Inference.LSTM_BASIC_CELL:
                return tf.nn.rnn_cell.BasicLSTMCell(self._hidden_unit_size)
            elif model_type == Inference.GRU_CELL:
                return tf.nn.rnn_cell.GRUCell(self._hidden_unit_size)
            else:
                raise RuntimeError('unknown model type: ' + model_type)

        # each layer needs its own cell instance; [cell] * num_layers would make
        # every layer share the same cell object (and raise a reuse error in newer TF 1.x)
        multi_cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(self._num_layers)])

        initial_state = multi_cell.zero_state(self._batch_size, tf.float32)

        # to build the RNN model, the integer inputs must first be embedded

        with tf.variable_scope('rnn'):
            softmax_w = tf.get_variable('softmax_w', [self._hidden_unit_size, self._class_num], tf.float32)
            softmax_b = tf.get_variable('softmax_b', [self._class_num], tf.float32)

            # build the embedding of the inputs; the lookup is pinned to the CPU
            with tf.device('/cpu:0'):
                embedding = tf.get_variable('embedding', [self._class_num, self._hidden_unit_size])
                inputs_embedding = tf.nn.embedding_lookup(embedding, inputs)

        outputs, last_state = tf.nn.dynamic_rnn(multi_cell, inputs_embedding, initial_state=initial_state,
                                                scope='rnn', dtype=tf.float32)

        logging.info('outputs shape: ' + str(outputs.shape))
        # flatten outputs from [batch, time, hidden] to [batch * time, hidden]
        outputs = tf.reshape(outputs, [-1, self._hidden_unit_size])

        logit = tf.matmul(outputs, softmax_w) + softmax_b
        prob = tf.nn.softmax(logit)

        loss = None
        cost = None

        if targets is not None:
            # per-character cross-entropy; every time step gets an equal weight of 1
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logit],
                                                                      [targets],
                                                                      [tf.ones_like(targets, dtype=tf.float32)])
            cost = tf.reduce_mean(loss)

        return cost, prob, loss, last_state, logit, multi_cell, initial_state
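
A quick shape walk-through of the graph above, as a hypothetical sketch (the hyperparameter values are illustrative): inputs of shape [batch_size, time] are embedded to [batch_size, time, hidden_unit_size], dynamic_rnn keeps that shape for outputs, and the softmax projection maps every (batch, time) position onto the vocabulary:

import tensorflow as tf
from poem_inference import Inference

inference = Inference(hidden_unit_size=128, num_layers=2, class_num=6000, batch_size=4)
x = tf.placeholder(tf.int32, [4, None])  # [batch_size, time]
y = tf.placeholder(tf.int32, [4, None])

cost, prob, loss, last_state, logit, _, _ = inference.inference(Inference.BASIC_RNN_CELL, x, y)
print(prob.shape)   # (?, 6000): one distribution per (batch, time) position
print(logit.shape)  # (?, 6000)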

Training

# coding:utf8
"""
Training
"""

import tensorflow as tf
import common
from poem_inference import Inference
from poem_input import PoemInput
import logging


class Train(object):

    def train(self):

        poem_input = PoemInput('./input/poetry.txt', common.TRAIN_BATCH_SIZE)
        poem_input.process()

        num_class = len(poem_input.word_dict) + 1

        x_placeholder = tf.placeholder(tf.int32, [common.TRAIN_BATCH_SIZE, None])
        y_placeholder = tf.placeholder(tf.int32, [common.TRAIN_BATCH_SIZE, None])

        logging.info('y shape 1: ' + str(y_placeholder.shape))

        inference = Inference(common.HIDDEN_UNIT_SIZE,
                              common.NUM_LAYERS,
                              num_class,
                              common.TRAIN_BATCH_SIZE)

        info = inference.inference(Inference.BASIC_RNN_CELL,
                                   x_placeholder,
                                   y_placeholder)

        cost = info[0]

        learning_rate = tf.Variable(0.01, trainable=False)
        tvars = tf.trainable_variables()
        # clip gradients to a global norm of 5 to avoid exploding gradients
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars))

        # build the learning-rate update op once, outside the training loop,
        # so the graph does not grow with every step
        lr_placeholder = tf.placeholder(tf.float32, [])
        lr_update = tf.assign(learning_rate, lr_placeholder)

        with tf.Session() as session:

            session.run([tf.global_variables_initializer(),
                         tf.local_variables_initializer()])

            saver = tf.train.Saver(tf.global_variables())

            decay_steps = common.NUM_STEPS // 10

            output_steps = common.NUM_STEPS // 500

            for step in range(common.NUM_STEPS):

                # staircase exponential decay: lr = 0.002 * 0.97 ** (step // decay_steps)
                session.run(lr_update, feed_dict={lr_placeholder: 0.002 * (0.97 ** (step // decay_steps))})

                batch_x, batch_y = poem_input.next_batch()

                cost_result, _ = session.run([cost, train_op],
                                             feed_dict={
                                                 x_placeholder: batch_x,
                                                 y_placeholder: batch_y
                })

                if step % output_steps == 0:
                    logging.info('step: %d, loss: %f' % (step, cost_result))
                    saver.save(session, './output/poem', global_step=step)
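
The learning-rate update in the loop implements a staircase exponential decay: the rate starts at 0.002 and is multiplied by 0.97 after every decay_steps steps. TF 1.x also ships this as a built-in; an equivalent sketch (using a global_step variable that the original code does not have):

global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(0.002, global_step,
                                           decay_steps=common.NUM_STEPS // 10,
                                           decay_rate=0.97, staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)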

Common configuration

# coding:utf-8
"""
common definitions
"""

TRAIN_BATCH_SIZE = 64  # batch size used for training
HIDDEN_UNIT_SIZE = 128  # number of hidden units per RNN layer
NUM_LAYERS = 2  # number of stacked RNN layers

NUM_STEPS = 100000  # total number of training steps

GEN_BATCH_SIZE = 1  # batch size used for generation

Training entry point (main)

# coding:utf-8
"""
main entry point
"""

from poem_train import Train
import logging


if __name__ == '__main__':
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s [%(filename)s:%(lineno)d]",
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename='./log/rnn_poem_train.log',
        filemode='w'
    )

    train = Train()
    train.train()

Generating poems with the trained model

# coding:utf-8
"""
Poem generation
"""

import numpy as np
from poem_input import PoemInput
from poem_inference import Inference
import logging
import common
import tensorflow as tf
import json
import sys


class PoemGen(object):

    def __init__(self, poem_path, model_path):
        self._poem_input = PoemInput(poem_path, 1)
        self._prepare()

        num_classes = len(self._poem_input.word_dict) + 1
        self._inference = Inference(common.HIDDEN_UNIT_SIZE,
                                    common.NUM_LAYERS,
                                    num_classes,
                                    common.GEN_BATCH_SIZE)

        self._model_path = model_path

    def _prepare(self):
        self._poem_input.process()

    def _to_word(self, prob):
        """Sample a character from the softmax distribution via inverse-CDF sampling."""

        logging.info('prob: ' + str(prob))
        t = np.cumsum(prob)

        logging.info('t: ' + str(t))
        s = np.sum(prob)

        logging.info('s: ' + str(s))

        # draw a uniform number in [0, s) and locate it in the cumulative distribution
        sample = int(np.searchsorted(t, np.random.rand(1) * s))

        logging.info('sample: ' + str(sample))

        if sample not in self._poem_input.index_dict:
            logging.info('index_dict: ' + str(self._poem_input.index_dict))
        return self._poem_input.index_dict[sample]

        # Alternative (unused): greedy decoding -- always pick the most probable character
        # target = prob[0]
        #
        # target_max = 1e-20
        # max_index = -1
        # for index in range(len(target)):
        #     if target[index] > target_max:
        #         target_max = target[index]
        #         max_index = index
        #
        # return self._poem_input.index_dict[max_index]

    def gen(self):
        """
        生成诗歌
        :return:
        """

        word_index_vector = self._poem_input.convert_poem_2_poem_vector(['['])
        x = np.array([word_index_vector])

        logging.debug('x shape: ' + str(x.shape))
        logging.debug('x content: ' + str(x))

        x_placeholder = tf.placeholder(tf.int32, [1, None])
        _, prob, _, last_state, _, _, initial_state = self._inference.inference(Inference.BASIC_RNN_CELL,
                                                                                x_placeholder, None)

        with tf.Session() as session:

            session.run([tf.global_variables_initializer(),
                        tf.local_variables_initializer()])

            saver = tf.train.Saver(tf.global_variables())

            saver.restore(session, self._model_path)

            prob_result, last_state_result = session.run([prob, last_state], feed_dict={
                x_placeholder: x
            })

            word = self._to_word(prob_result)

            poem = '[' + word
            while word != ']':
                word_index_vector = self._poem_input.convert_poem_2_poem_vector([word])
                x = np.array([word_index_vector])

                prob_result, last_state_result = session.run([prob, last_state], feed_dict={
                    x_placeholder: x,
                    # feed the previous state back in so the RNN continues the sequence
                    initial_state: last_state_result
                })

                word = self._to_word(prob_result)

                poem += word
                logging.info('poem: ' + json.dumps(poem, ensure_ascii=False).encode('utf-8'))

            return poem


if __name__ == '__main__':

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s [%(filename)s:%(lineno)d]",
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename='./log/rnn_poem_gen.log',
        filemode='w'
    )

    if len(sys.argv) != 2:
        logging.fatal('please input model path')
        exit(-1)

    poem_gen = PoemGen('./input/poetry.txt', sys.argv[1])
    poem = poem_gen.gen()
    print('poem: ', json.dumps(poem, ensure_ascii=False).encode('utf-8'))
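
The _to_word method above samples a character via inverse-CDF sampling: draw a uniform number in [0, s) and locate it in the cumulative distribution, so higher-probability characters are picked proportionally more often. A standalone illustration with toy numbers (not from the model):

import numpy as np

prob = np.array([0.1, 0.6, 0.3])  # toy distribution over 3 characters
t = np.cumsum(prob)               # [0.1, 0.7, 1.0]
s = np.sum(prob)                  # 1.0
sample = int(np.searchsorted(t, np.random.rand() * s))
# sample is 0, 1 or 2, drawn with probability ~0.1, ~0.6, ~0.3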

Generated result

This model was trained for only 20,000 steps; training longer would help. It also uses the plain RNN cell rather than an LSTM, and an LSTM would do better (see the sketch after the sample below).

金藕耸传木,思天希保官。王化献兹离,连云何近闻。努戈峨轼至,军主就阳川。空使君明愧,囊当白露华。青帘对雪咽,箫雅岂遥游。恭渟万古史,四陆觅其行。
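
As noted above, switching to an LSTM is a one-line change: pass Inference.LSTM_BASIC_CELL instead of Inference.BASIC_RNN_CELL when building the graph. A sketch of the change in the training script (the generation script must pass the same cell type so the checkpoint variables match):

info = inference.inference(Inference.LSTM_BASIC_CELL,
                           x_placeholder,
                           y_placeholder)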