有趣的深度学习——使用TensorFlow 2.0 + RNN 实现一个古体诗生成器

本文介绍了使用TensorFlow 2.0和RNN重写古体诗生成器的过程,包括数据集处理、模型构建、训练与测试。作者详细阐述了数据预处理、构建分词器、模型结构、训练策略以及模型评估方法,并提供了代码示例。
摘要由CSDN通过智能技术生成

一、前言

很早之前,我曾经写过一个古体诗生成器(详情可以戳TensorFlow练手项目二:基于循环神经网络(RNN)的古诗生成器),那个时候用的还是Python 2.7和TensorFlow 1.4。

随着框架的迭代,API 的变更,老项目已经很难无障碍运行起来了。有不少朋友在老项目下提出了各种问题,于是,我就萌生了使用TensorFlow 2.0重写项目的想法。

这不,终于抽空,重写了这个项目。

完整的项目已经放到了GitHub上:

AaronJny/DeepLearningExamples/tf2-rnn-poetry-generator (https://github.com/AaronJny/DeepLearningExamples/tree/master/tf2-rnn-poetry-generator)

先对项目做个简单展示。项目主要包含如下功能:

  • 使用唐诗数据集训练模型。
  • 使用训练好的模型,随机生成一首古体诗。
  • 使用训练好的模型,续写一首古体诗。
  • 使用训练好的模型,随机生成一首藏头诗。

随机生成一首古体诗:

金鹤有僧心,临天寄旧身。
石松惊枕树,红鸟发禅新。
不到风前远,何人怨夕时。
明期多尔处,闲此不依迟。
水泉临鸟声,北去暮空行。
林阁多开雪,楼庭起洞城。
夜来疏竹外,柳鸟暗苔清。
寂寂重阳里,悠悠一钓矶。

续写一首古体诗(以"床前明月光,"为例):

床前明月光,翠席覆银丝。
岁气分龙阁,无人入鸟稀。
圣明无泛物,云庙逐雕旗。
永夜重江望,南风正送君。
床前明月光,清水入寒云。
远景千山雨,萧花入翠微。
影云虚雪润,花影落云斜。
独去江飞夜,谁能作一花。

随机生成一首藏头诗(以"海阔天空"为例):

海口多无定,
阔庭何所难。
天山秋色上,
空石昼尘连。
海庭愁不定,
阔处到南关。
天阙青秋上,
空城雁渐催。

下面开始讲解项目实现过程。

转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103806954

二、数据集处理

跟老项目一样,我们仍然使用四万首唐诗的文本作为训练集(已经上传,可以直接从GitHub上下载)。我们打开文本,看一下数据格式:

在这里插入图片描述

能够看到,文本中每行是一首诗,且使用冒号分割,前面是标题,后面是正文,且诗的长度不一。

我们对数据的处理流程大致如下:

  1. 读取文本,按行切分,构成古诗列表。
  2. 将全角、半角的冒号统一替换成半角的。
  3. 按冒号切分诗的标题和内容,只保留诗的内容。
  4. 考虑到模型的大小,我们只保留内容长度小于一定长度的古诗。
  5. 统计保留的诗中的词频,去掉低频词,构建词汇表。

代码如下:

# -*- coding: utf-8 -*-
# @File    : dataset.py
# @Author  : AaronJny
# @Time    : 2019/12/30
# @Desc    : 构建数据集
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings

# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE

# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

# 统计词频
counter = Counter()
for line in poetry:
    counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)

代码很简单,注释也很清晰,就不一行一行说了。有几点需要注意一下:

  • 我们需要一些特殊字符,以完成特定的功能。这里使用的特殊字符有四个,为’[PAD]’, ‘[UNK]’, ‘[CLS]’, ‘[SEP]’,它们分别代表填充字符、低频词、古诗开始标记、古诗结束标记。
  • 代码中出现了一个类——Tokenizer,这是为了方便我们完成字符转编号、编号转字符、字符串转编号序列、编号序列转字符串等操作而编写的一个辅助类。它的代码也很简单,我们来看一下。
class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {
   value: key 
评论 41
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值