飞桨2.0应用案例教程 — 用BERT实现自动写诗

用BERT实现自动写诗

作者fiyen

日期:2021.06

摘要:本示例教程将会演示如何使用飞桨2.0以及PaddleNLP快速实现用BERT预训练模型生成高质量诗歌。

摘要

古诗,中华民族最高贵的文化瑰宝,在几千年文化传承中扮演着重要的角色。诗歌已经融入中华儿女的血脉之中,上到古稀之人,下到刚入学的孩童,都能随口吟诵一首诗出来。诗句的运用体现了古今诗人对文字运用的娴熟技艺,同时寄托着诗人深远的情思。诗句或优美或刚劲,或温婉或苍凉,让人在阅读诗歌的时候,如沐春风,身临其境。

美好的诗歌让人心向往之,当我们的眼球接受了美好景物时,谁不曾有“此情此景,我想吟诗一首”的冲动,却限于实力张口息声,半晌想不出一个合适的表达。此时,如果我们有一个强大的诗歌生成工具,岂不美哉?

没问题,通过飞桨,搭建一个古诗自动生成模型将不再是一个困难的事情。在这里,我们将展示如何用飞桨快速搭建一个强大的古诗生成模型。

在这个示例中,我们将快速构建基于BERT预训练模型的古诗生成器,支持诗歌风格定制,以及生成藏头诗。模型基于飞桨2.0框架,BERT预训练模型则调用自PaddleNLP,诗歌数据集采用Github开源数据集。

1. 相关内容介绍

1.1 PaddleNLP

官网链接:https://github.com/PaddlePaddle/PaddleNLP

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NJj22JLN-1635407834342)(https://github.com/fiyen/models/raw/release/2.0-beta/PaddleNLP/docs/imgs/paddlenlp.png)]

PaddleNLP旨在帮助开发者提高文本建模的效率,通过丰富的模型库、简洁易用的API,提供飞桨2.0的最佳实践并加速NLP领域应用产业落地效率。其产品特性如下:

  • 丰富的模型库

涵盖了NLP主流应用相关的前沿模型,包括中文词向量、预训练模型、词法分析、文本分类、文本匹配、文本生成、机器翻译、通用对话、问答系统等。

  • 简洁易用的API

深度兼容飞桨2.0的高层API体系,提供更多可复用的文本建模模块,可大幅度减少数据处理、组网、训练环节的代码开发,提高开发效率。

  • 高性能分布式训练

通过高度优化的Transformer网络实现,结合混合精度与Fleet分布式训练API,可充分利用GPU集群资源,高效完成预训练模型的分布式训练。

1.2 BERT

BERT的全称为Bidirectional Encoder Representations from Transformers,即基于Transformers的双向编码表示模型。BERT是Transformers应用的一次巨大的成功。在该模型提出时,其在NLP领域的11个方向上都大幅刷新了SOTA。其模型的主要特点可以归纳如下:

  1. 基于Transformer。Transformer的提出将注意力机制的应用发挥到了极致,同时也解决了基于RNN的注意力机制的无法并行计算的问题,使超大规模的模型训练在时间上变得可以接受;

  2. 双向编码。其实双向编码不是BERT首创,但是基于Transformer与双向编码结合使这一做法的效用得到了最充分的发挥;

  3. 使用MLM(Mask Language Model)和NSP(Next Sentence Prediction)实现多任务训练的目标。

  4. 迁移学习。BERT模型展现出了大规模数据训练带来的有效性,而更重要的一点是,BERT实质上是一种更好的语义表征,相较于经典的Word2Vec,Glove等模型具有更好词嵌入特征。在实际应用中,我们可以直接调用训练好的BERT模型作为特征表示,进而设计下游任务。

2. 数据设置

在这一部分,我们将介绍使用的数据集,并展示数据集的调用方法。

2.1 数据准备

诗歌数据集采用Github上开源的中华古诗词数据库

该数据集包含了唐宋两朝近一万四千古诗人, 接近5.5万首唐诗加26万宋诗. 两宋时期1564位词人,21050首词。其中,唐宋诗歌内容在json文件夹下,这里只使用json文件夹下的数据即可。以下式单个数据的示例:

{
  "author":string"胡宿"
  "paragraphs":[
    "五粒青松護翠苔,石門岑寂斷纖埃。"
    "水浮花片知仙路,風遞鸞聲認嘯臺。"
    "桐井曉寒千乳斂,茗園春嫩一旗開。"
    "馳煙未勒山亭字,可是英靈許再來。"
  ]
  "title":string"沖虛觀"
  "id":string"dad91d22-4b8a-4c04-a0d5-8f7ca8aff4de"
}
,…]

可见,此数据集中多数诗歌内容为繁体字。不过不用担心,飞桨已经内置了该数据集并且已经进行了简体化,我们可以通过简单的几行代码迅速调用该数据集!如下所示:

# 更新paddlenlp版本
!pip install --upgrade paddlenlp
import paddlenlp
test_dataset, dev_dataset, train_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('test','dev','train'), lazy=False)
print('test_dataset 的样本数量:%d'%len(test_dataset))
print('dev_dataset 的样本数量:%d'%len(dev_dataset))
print('train_dataset 的样本数量:%d'%len(train_dataset))
test_dataset 的样本数量:364
dev_dataset 的样本数量:995
train_dataset 的样本数量:294598

以上三个数据,train_dataset为训练集,test_dataset为测试集,dev_dataset为开发集。其中开发集用于训练过程的测试,以用来选择最合适的模型参数,避免模型过拟合。

2.2 数据处理

如下为以上数据单样本的实例:

print('单样本示例:%s'%test_dataset[0])
单样本示例:{'tokens': '西\x02风\x02簇\x02浪\x02花\x02,\x02太\x02湖\x02连\x02底\x02冻\x02。', 'labels': '冷\x02照\x02玉\x02奁\x02清\x02,\x02一\x02片\x02无\x02瑕\x02缝\x02。\x02面\x02目\x02分\x02明\x02,\x02眼\x02睛\x02定\x02动\x02。\x02不\x02墯\x02虚\x02凝\x02裂\x02万\x02差\x02,\x02漆\x02桶\x02漆\x02桶\x02。'}

从单个样本的实例中可以看到,每个样本都有两句。为了方便处理,这里我们直接将两句合成一句进行训练。训练中我们将用每个诗句当前的字去预测下一个字,假设我们有样本sample, 那么我们的输入为sample[:-1],要预测的目标为sample[1:]。诗句中每个字后边都有符号’\x02’,由于对当前的训练并没有帮助,所以我们将其替换掉。

import re
def data_preprocess(dataset):
    for i, data in enumerate(dataset):
        dataset.data[i] = ''.join(list(dataset[i].values()))
        dataset.data[i] = re.sub('\x02', '', dataset[i])
    return dataset
# 开始处理
test_dataset = data_preprocess(test_dataset)
dev_dataset = data_preprocess(dev_dataset)
train_dataset = data_preprocess(train_dataset)
print('处理后的单样本示例:%s'%test_dataset[0])
处理后的单样本示例:西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。

从PaddleNLP调用基于BERT预训练模型的分词工具,对诗歌进行分词和编码。

from paddlenlp.transformers import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
[2021-06-09 15:35:07,276] [    INFO] - Downloading bert-base-chinese-vocab.txt from https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-chinese-vocab.txt
100%|██████████| 107/107 [00:00<00:00, 23081.18it/s]

处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“[CLS]”标记(“[CLS]”是对一些特殊任务的留空项,对于需要此项功能的并需要标记语句开始的情况,一般会再加上“[BOS]”),在结尾添加“[SEP]”标记(需要区分句子的编码中,这个标记用来将不同的句子隔开,结尾添加“[EOS]”),这些标记在BERT模型训练中扮演者特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“[UNK]”表示分词工具无法识别的符号,“[PAD]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。

# 处理效果展示
for poem in test_dataset[0:2]:
    token_poem, _ = bert_tokenizer.encode(poem).values()
    print(poem)
    print(token_poem)
    print(''.join(bert_tokenizer.convert_ids_to_tokens(token_poem)))
西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。
[101, 6205, 7599, 5077, 3857, 5709, 8024, 1922, 3959, 6825, 2419, 1108, 511, 1107, 4212, 4373, 100, 3926, 8024, 671, 4275, 3187, 4442, 5361, 511, 7481, 4680, 1146, 3209, 8024, 4706, 4714, 2137, 1220, 511, 679, 100, 5994, 1125, 6162, 674, 2345, 8024, 4024, 3446, 4024, 3446, 511, 102]
[CLS]西风簇浪花,太湖连底冻。冷照玉[UNK]清,一片无瑕缝。面目分明,眼睛定动。不[UNK]虚凝裂万差,漆桶漆桶。[SEP]
大道分明在眼前,时人不会悮归泉。黄芽本是乾坤气,神水根基与汞连。
[101, 1920, 6887, 1146, 3209, 1762, 4706, 1184, 8024, 3198, 782, 679, 833, 100, 2495, 3787, 511, 7942, 5715, 3315, 3221, 746, 1787, 3698, 8024, 4868, 3717, 3418, 1825, 680, 3735, 6825, 511, 102]
[CLS]大道分明在眼前,时人不会[UNK]归泉。黄芽本是乾坤气,神水根基与汞连。[SEP]

2.3 构造数据读取器

预处理数据后,我们基于飞桨2.0构造数据读取器,以适应后续模型的训练。

在构造读取器之前,我们先来了解一下BERT模型的输入是什么样子的。如下图所示:

上图中可以清晰地显示出输入数据的具体样式,包括三个部分:Token Embeddings, Segment Embeddings, Position Embeddings。在这里,Embeddings理解为嵌入,即将一个元素表示成一个1 * n的向量的形式,用以表示这个元素在一个向量空间的相对位置。这是中文文本处理如今比较普遍采用的方式。在这里,Token Embeddings为词嵌入,将分词后的词元素映射成一个个1 * n的向量。除此之外,Segment Embeddings表示每个词元素属于何种角色。具体来说,当我们需要区分一个输入中不同语句时,如在对话模型中,区分输入中每一句话是哪个对象发出的,可以用Segment Embeddings。Position Embeddings为Transformer类模型的特色,由于此类自注意力机制无法区分距离的远近,引入了该嵌入来增加距离产生的偏置。通常情况下,Position为一个从句首到句尾渐增的数列,如[0,1,2,3,4,5,…,n-1]即表示一个长度为n的输入的Position。如何得到Embeddings呢?通常是构造一个N * n的矩阵,所有元素被唯一对应一个位置索引,元素数量不大于N。每一个元素的嵌入通过其对应的索引调取矩阵对应的行的n个列上的元素,即1 * n的向量。在这个项目中,由于不需要区分每一句的角色,Segment Embeddings可以设为一样的,即索引都为相同的值 (如0)。由于飞桨的BERT模型会自动处理Segment Embeddings和Position Embeddings,在构造输入的时候我们可以忽略这两项。在进行下一步计算前,所有类型的进行加和,每个词元素对应一个合成的嵌入向量。

需注意以下类定义中包含填充内容,使输入样本对齐到一个特定的长度,以便于模型进行批处理运算。因此在得到数据读取器的实例时,需注意参数max_len,其不超过模型所支持的最大长度(PaddleNLP默认的序列最长长度为512)

import paddle
from paddle.io import Dataset
import numpy as np

class PoemData(Dataset):
    """
    构造诗歌数据集,继承paddle.io.Dataset
    Parameters:
        poems (list): 诗歌数据列表,每一个元素为一首诗歌,诗歌未经编码
        max_len: 接收诗歌的最大长度
    """
    def __init__(self, poems, tokenizer, max_len=128):
        super(PoemData, self).__init__()
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __getitem__(self, idx):
        line = self.poems[idx]
        token_line = self.tokenizer.encode(line)
        token, token_type = token_line['input_ids'], token_line['token_type_ids']
        if len(token) > self.max_len + 1:
            token = token[:self.max_len] + token[-1:]
            token_type = token_type[:self.max_len] + token_type[-1:]
        input_token, input_token_type = token[:-1], token_type[:-1]
        label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')
        # 输入填充
        input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_pad_mask = (input_token != 0).astype('float32')
        return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask
    
    def __len__(self):
        return len(self.poems)

3. 模型设置与训练

在这一部分,我们将快速搭建基于BERT预训练模型的古诗生成器,并对模型进行训练。

3.1 预训练BERT模型

古诗生成是一个文本生成的过程,在实际中模型无法获知还未生成的内容,也即BERT中的双向关系中只能捕捉到前向关系而不能捕捉到后向关系。这个限制我们可以通过添加注意力掩码(attention mask)来屏蔽掉后向的关系,使模型无法注意到还未生成的内容,从而使BERT仍能完成文本生成任务。

进一步地,我们可以将文本生成简化为基于BERT的词分类模型(理解为词性标注),即赋予每个词一个标签,该标签即该词后的下一个词是什么。下表为一个示例:对于诗句“床前明月光,疑是地上霜。”来说,在训练的时候,输入为“床前明月光,疑是地上霜”(注意没有“。”),而预测的内容为输入的每个词对应的标签,我们把其预测标签设置为“前明月光,疑是地上霜。”在这里,我们可以理解为,文字“床”对应的标签为“前”、文字“前”对应的标签为“明”、…、文字“霜”对应的标签为“。”。因此,我们直接调用PaddleNLP的BERT词分类模型即可,需注意模型分类的类别为词表长度。

句子床前明月光,疑是地上霜。
输入床前明月光,疑是地上霜
预测前明月光,疑是地上霜。
流程如下
根据内容:床预测内容:前
根据内容:床前预测内容:明
根据内容:床前明预测内容:月
根据内容:床前明月光,疑是地上霜预测内容:。
from paddlenlp.transformers import BertModel, BertForTokenClassification
from paddle.nn import Layer, Linear, Softmax

class PoetryBertModel(Layer):
    """
    基于BERT预训练模型的诗歌生成模型
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModel, self).__init__()
        bert_model = BertModel.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertForTokenClassification(bert_model, self.vocab_size)
        # 生成下三角矩阵,用来mask句子后边的信息
        self.sequence_length = input_length
        # lower_triangle_mask为input_length * input_length的下三角矩阵(包含主对角线),该掩码作为注意力掩码的一部分(在forward的
        # 处理中为0的部分会被处理成无穷小量,以方便在计算注意力权重的时候保证被掩盖的部分权重约等于0)。而之所以写为下三角矩阵的形式,与
        # transformer的多头注意力计算的机制有关,细节可以了解相关论文获悉。
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # 计算attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # 输入句子中有效的位置
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # 注意力机制计算中有效的位置
        if input_length is not None:
            # 之所以要再计算一次,是因为用于推理预测时,可能输入的长度不为实例化时设置的长度。这里的模型在训练时假设输入的
            # 长度是被填充成一致的——这一步不是必须的,但是处理成一致长度比较方便处理(对应地,增加了显存的用度)。
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # 无效的位置设为极小值
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)

        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)
        
        return output_logits

3.2 定义模型损失

由于真实值中有相当一部分是填充内容,我们需重写交叉熵损失,使其忽略填充内容带来的损失。

class PoetryBertModelLossCriterion(Layer):
    def forward(self, pred_logits, label, input_mask):
        loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')
        masked_loss = paddle.mean(loss * input_mask, axis=0)
        return paddle.sum(masked_loss)

3.3 模型准备

针对预训练模型的训练,需使用较小的学习率(learning_rate)进行调优。

from paddle.static import InputSpec
from paddlenlp.metrics import Perplexity
from paddle.optimizer import AdamW

net = PoetryBertModel('bert-base-chinese', 128)

token_ids = InputSpec((-1, 128), 'int64', 'token')
token_type_ids = InputSpec((-1, 128), 'int64', 'token_type')
input_mask = InputSpec((-1, 128), 'float32', 'input_mask')
label = InputSpec((-1, 128), 'int64', 'label')

inputs = [token_ids, token_type_ids, input_mask]
labels = [label, input_mask]

model = paddle.Model(net, inputs, labels)
model.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])

model.summary(inputs, [input.dtype for input in inputs])
[2021-06-05 09:16:08,229] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-06-05 09:16:17,229] [    INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


----------------------------------------------------------------------------------------------------------------------------------------
        Layer (type)                                   Input Shape                                 Output Shape            Param #    
========================================================================================================================================
        Embedding-1                                    [[1, 128]]                                  [1, 128, 768]         16,226,304   
        Embedding-2                                    [[1, 128]]                                  [1, 128, 768]           393,216    
        Embedding-3                                    [[1, 128]]                                  [1, 128, 768]            1,536     
        LayerNorm-1                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Dropout-1                                   [[1, 128, 768]]                               [1, 128, 768]              0       
      BertEmbeddings-1                                     []                                      [1, 128, 768]              0       
          Linear-1                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
          Linear-2                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
          Linear-3                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
          Linear-4                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-1     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-3                                   [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-2                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
          Linear-5                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-2                                  [[1, 128, 3072]]                              [1, 128, 3072]              0       
          Linear-6                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-4                                   [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-3                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-1                           [[1, 128, 768]]                               [1, 128, 768]              0       
          Linear-7                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
          Linear-8                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
          Linear-9                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-10                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-2     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-6                                   [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-4                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-11                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-5                                  [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-12                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-7                                   [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-5                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-2                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-13                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-14                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-15                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-16                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-3     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-9                                   [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-6                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-17                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-8                                  [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-18                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-10                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-7                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-3                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-19                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-20                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-21                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-22                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-4     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-12                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-8                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-23                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-11                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-24                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-13                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-9                                  [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-4                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-25                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-26                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-27                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-28                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-5     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-15                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-10                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-29                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-14                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-30                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-16                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-11                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-5                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-31                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-32                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-33                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-34                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-6     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-18                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-12                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-35                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-17                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-36                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-19                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-13                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-6                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-37                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-38                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-39                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-40                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-7     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-21                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-14                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-41                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-20                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-42                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-22                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-15                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-7                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-43                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-44                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-45                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-46                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-8     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-24                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-16                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-47                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-23                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-48                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-25                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-17                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-8                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-49                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-50                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-51                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-52                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
    MultiHeadAttention-9     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-27                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-18                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-53                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-26                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-54                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-28                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-19                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-9                           [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-55                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-56                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-57                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-58                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-10     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-30                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-20                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-59                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-29                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-60                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-31                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-21                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-10                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-61                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-62                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-63                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-64                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-11     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-33                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-22                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-65                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-32                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-66                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-34                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-23                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-11                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-67                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-68                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-69                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-70                                   [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-12     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
         Dropout-36                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-24                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-71                                   [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
         Dropout-35                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-72                                  [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
         Dropout-37                                  [[1, 128, 768]]                               [1, 128, 768]              0       
        LayerNorm-25                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-12                          [[1, 128, 768]]                               [1, 128, 768]              0       
    TransformerEncoder-1                    [[1, 128, 768], [1, 1, 128, 128]]                      [1, 128, 768]              0       
         Linear-73                                     [[1, 768]]                                    [1, 768]              590,592    
           Tanh-2                                      [[1, 768]]                                    [1, 768]                 0       
        BertPooler-1                                 [[1, 128, 768]]                                 [1, 768]                 0       
        BertModel-1                                    [[1, 128]]                            [[1, 128, 768], [1, 768]]        0       
         Dropout-38                                  [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-74                                   [[1, 128, 768]]                              [1, 128, 21128]        16,247,432   
BertForTokenClassification-1                           [[1, 128]]                                 [1, 128, 21128]             0       
========================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 219.04
Params size (MB): 452.10
Estimated Total Size (MB): 671.14
----------------------------------------------------------------------------------------------------------------------------------------






{'total_params': 118515080, 'trainable_params': 118515080}

3.4 模型训练

由于调用了预训练模型,再次调优,只需很少轮的训练即可达到较好的效果。

训练过程中,设置save_dir参数来保存训练的模型,并通过save_freq设置保存的频率。

from paddle.io import DataLoader

train_loader = DataLoader(PoemData(train_dataset, bert_tokenizer, 128), batch_size=128, shuffle=True)
dev_loader = DataLoader(PoemData(dev_dataset, bert_tokenizer, 128), batch_size=32, shuffle=True)
model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1, eval_data=dev_loader, eval_freq=1)

4. 古诗生成

以下,我们定义一个类来利用已经训练好的模型完成古诗生成的任务。在生成古诗的过程中,我们将已经生成的内容作为输入,编码后输入模型,得到输入中每个词对应的分类结果。然后选取最后一个词的分类结果作为根据当前内容要预测的词。下一轮中,刚刚预测的词将加入到已生成的内容中,继续进行下一个词的预测。

在每轮预测结果的选择中,我们可以使用贪婪的方式选取最优的结果,也可以从前几个较优结果中随机选取(可以得到更多的组合),在这里,用topk进行控制。topk的设置不应太大,否则与随机生成差别不大。

import numpy as np

class PoetryGen(object):
    """
    定义一个自动生成诗句的类,按照要求生成诗句
    model: 训练得到的预测模型
    tokenizer: 分词编码工具
    max_length: 生成诗句的最大长度,需小于等于model所允许的最大长度
    """
    def __init__(self, model, tokenizer, max_length=512):
        self.model = model
        self.tokenizer = tokenizer
        self.puncs = [',', '。', '?', ';']
        self.max_length = max_length

    def generate(self, style='', head='', topk=2):
        """
        根据要求生成诗句
        style (str): 生成诗句的风格,写成诗句的形式,如“大漠孤烟直,长河落日圆。”
        head (str, list): 生成诗句的开头内容。若head为str格式,则head为诗句开始内容;
            若head为list格式,则head中每个元素为对应位置上诗句的开始内容(即藏头诗中的头)。
        topk (int): 从预测的topk中选取结果
        """
        head_index = 0
        style_ids = self.tokenizer.encode(style)['input_ids']
        # 去掉结束标记
        style_ids = style_ids[:-1]
        head_is_list = True if isinstance(head, list) else False
        if head_is_list:
            poetry_ids = self.tokenizer.encode(head[head_index])['input_ids']
        else:
            poetry_ids = self.tokenizer.encode(head)['input_ids']
        # 去掉开始和结束标记
        poetry_ids = poetry_ids[1:-1]
        break_flag = False
        while len(style_ids) + len(poetry_ids) <= self.max_length:
            next_word = self._gen_next_word(style_ids + poetry_ids, topk)
            # 对于一些符号,如[UNK], [PAD], [CLS]等,其产生后对诗句无意义,直接跳过
            if next_word in self.tokenizer.convert_tokens_to_ids(['[UNK]', '[PAD]', '[CLS]']):
                continue
            if head_is_list:
                if next_word in self.tokenizer.convert_tokens_to_ids(self.puncs):
                    head_index += 1
                    if head_index < len(head):
                        new_ids = self.tokenizer.encode(head[head_index])['input_ids']
                        new_ids = [next_word] + new_ids[1:-1]
                    else:
                        new_ids = [next_word]
                        break_flag = True
                else:
                    new_ids = [next_word]
            else:
                new_ids = [next_word]
            if next_word == self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0]:
                break
            poetry_ids += new_ids
            if break_flag:
                break
        return ''.join(self.tokenizer.convert_ids_to_tokens(poetry_ids))

    def _gen_next_word(self, known_ids, topk):
        type_token = [0] * len(known_ids)
        mask = [1] * len(known_ids)
        sequence_length = len(known_ids)
        known_ids = paddle.to_tensor([known_ids], dtype='int64')
        type_token = paddle.to_tensor([type_token], dtype='int64')
        mask = paddle.to_tensor([mask], dtype='float32')
        logits = self.model.network.forward(known_ids, type_token, mask, sequence_length)
        # logits中对应最后一个词的输出即为下一个词的概率
        words_prob = logits[0, -1, :].numpy()
        # 依概率倒序排列后,选取前topk个词
        words_to_be_choosen = words_prob.argsort()[::-1][:topk]
        probs_to_be_choosen = words_prob[words_to_be_choosen]
        # 归一化
        probs_to_be_choosen = probs_to_be_choosen / sum(probs_to_be_choosen)
        word_choosen = np.random.choice(words_to_be_choosen, p=probs_to_be_choosen)
        return word_choosen

4.1 生成古诗示例

# 载入已经训练好的模型
net = PoetryBertModel('bert-base-chinese', 128)
model = paddle.Model(net)
model.load('./checkpoint/final')
poetry_gen = PoetryGen(model, bert_tokenizer)
def poetry_show(poetry):
    pattern = r"([,。;?])"
    text = re.sub(pattern, r'\1 ', poetry)
    for p in text.split():
        if p:
            print(p)
# 随机生成一首诗
poetry = poetry_gen.generate()
poetry_show(poetry)
一雨一晴天气新,
春风桃李不胜春。
山中老去无多事,
莫道山花不是真。
山色不随人意好,
花枝只与鸟情邻。
何时得见东君面,
共醉花光醉一身。
# 生成特定风格的诗
poetry = poetry_gen.generate(style='会当凌绝顶,一览众山小。')
poetry_show(poetry)
云外有时生,
云间无限好?
月明风细细,
松响竹萧悄。
谁识此时情?
相看情未了。
# 生成特定开头的诗
poetry = poetry_gen.generate(head='好好学习')
poetry_show(poetry)
好好学习子,
不如癡爱官。
一身无定价,
百事有馀安。
# 生成藏头诗
poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])

```python
# 生成特定开头的诗
poetry = poetry_gen.generate(head='好好学习')
poetry_show(poetry)
好好学习子,
不如癡爱官。
一身无定价,
百事有馀安。
# 生成藏头诗
poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])
poetry_show(poetry)
飞来峰下白莲宫,
桨去帆来一叶东。
真境自然非世外,
好山长与白云通?
  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值