一、前言
很早之前,我曾经写过一个古体诗生成器(详情可以戳TensorFlow练手项目二:基于循环神经网络(RNN)的古诗生成器),那个时候用的还是Python 2.7和TensorFlow 1.4。
随着框架的迭代,API 的变更,老项目已经很难无障碍运行起来了。有不少朋友在老项目下提出了各种问题,于是,我就萌生了使用TensorFlow 2.0重写项目的想法。
这不,终于抽空,重写了这个项目。
完整的项目已经放到了GitHub上:
AaronJny/DeepLearningExamples/tf2-rnn-poetry-generator (https://github.com/AaronJny/DeepLearningExamples/tree/master/tf2-rnn-poetry-generator)
先对项目做个简单展示。项目主要包含如下功能:
- 使用唐诗数据集训练模型。
- 使用训练好的模型,随机生成一首古体诗。
- 使用训练好的模型,续写一首古体诗。
- 使用训练好的模型,随机生成一首藏头诗。
随机生成一首古体诗:
金鹤有僧心,临天寄旧身。
石松惊枕树,红鸟发禅新。
不到风前远,何人怨夕时。
明期多尔处,闲此不依迟。
水泉临鸟声,北去暮空行。
林阁多开雪,楼庭起洞城。
夜来疏竹外,柳鸟暗苔清。
寂寂重阳里,悠悠一钓矶。
续写一首古体诗(以"床前明月光,"为例):
床前明月光,翠席覆银丝。
岁气分龙阁,无人入鸟稀。
圣明无泛物,云庙逐雕旗。
永夜重江望,南风正送君。
床前明月光,清水入寒云。
远景千山雨,萧花入翠微。
影云虚雪润,花影落云斜。
独去江飞夜,谁能作一花。
随机生成一首藏头诗(以"海阔天空"为例):
海口多无定,
阔庭何所难。
天山秋色上,
空石昼尘连。
海庭愁不定,
阔处到南关。
天阙青秋上,
空城雁渐催。
下面开始讲解项目实现过程。
转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103806954
二、数据集处理
跟老项目一样,我们仍然使用四万首唐诗的文本作为训练集(已经上传,可以直接从GitHub上下载)。我们打开文本,看一下数据格式:
能够看到,文本中每行是一首诗,且使用冒号分割,前面是标题,后面是正文,且诗的长度不一。
我们对数据的处理流程大致如下:
- 读取文本,按行切分,构成古诗列表。
- 将全角、半角的冒号统一替换成半角的。
- 按冒号切分诗的标题和内容,只保留诗的内容。
- 考虑到模型的大小,我们只保留内容长度小于一定长度的古诗。
- 统计保留的诗中的词频,去掉低频词,构建词汇表。
代码如下:
# -*- coding: utf-8 -*-
# @File : dataset.py
# @Author : AaronJny
# @Time : 2019/12/30
# @Desc : 构建数据集
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings
# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE
# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 将冒号统一成相同格式
lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
# 有且只能有一个冒号用来分割标题
if line.count(':') != 1:
continue
# 后半部分不能包含禁止词
__, last_part = line.split(':')
ignore_flag = False
for dis_word in disallowed_words:
if dis_word in last_part:
ignore_flag = True
break
if ignore_flag:
continue
# 长度不能超过最大长度
if len(last_part) > max_len - 2:
continue
poetry.append(last_part.replace('\n', ''))
# 统计词频
counter = Counter()
for line in poetry:
counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]
# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)
代码很简单,注释也很清晰,就不一行一行说了。有几点需要注意一下:
- 我们需要一些特殊字符,以完成特定的功能。这里使用的特殊字符有四个,为’[PAD]’, ‘[UNK]’, ‘[CLS]’, ‘[SEP]’,它们分别代表填充字符、低频词、古诗开始标记、古诗结束标记。
- 代码中出现了一个类——Tokenizer,这是为了方便我们完成字符转编号、编号转字符、字符串转编号序列、编号序列转字符串等操作而编写的一个辅助类。它的代码也很简单,我们来看一下。
class Tokenizer:
"""
分词器
"""
def __init__(self, token_dict):
# 词->编号的映射
self.token_dict = token_dict
# 编号->词的映射
self.token_dict_rev = {
value: key