CPT文本生成

8 篇文章 0 订阅

文本生成

最近在学习文本生成式任务,包括诗词生成,小说创作,摘要生成,title生成等。hugging face中的transformers中提供了T5,MT5,BART,GPT,GPT2等模型方便进行生成式任务。我最近在看了一个关于预训练模型CPT的介绍,也可以进行生成式任务。下面介绍CPT模型以及在在文本生成任务的简单尝试。

CPT模型

Chinese Pre-trained Unbalanced Transformer(CPT)模型是2022年论文CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Generation提出来的。CPT模型既可以做NLU任务也可以做NLG任务。模型结构如下:
cpt模型结构
CPT模型和transformer 的encoder-decoder结构相似,这里分了三个部分,第一个部分是encoder共享特征S-Enc,第二部分是理解式decoder由self-attention和预训练掩码模型(MLM)组成,记为U-Dec,第三部分式生成式decoder由mask的self-attention以及预训练DAE组成,记为G-Dec。
CPT模型可以看作是两个分离的decoder同时共享encoder的不平衡的Transformer的网络结构,使得模型更能胜任NLG或者NLU任务,同时分离的decoder也能灵活得fine-tuning下游任务。
CPT模型在CLUE benchmarks上的结果如下:
模型对比结果
CPT模型在NER任务上的表现:
在这里插入图片描述

模型应用

根据一段文本描述生成诗句。代码如下:

class CPTModelTextGenerator(nn.Module):
    def __init__(self, pretrain_model_file):
        super(CPTModelTextGenerator, self).__init__()
        self.pretrain_model_file = pretrain_model_file
        self.model = CPTForConditionalGeneration.from_pretrained(self.pretrain_model_file)

    def forward(self, input_ids, attention_mask, decoder_input_ids,
                decoder_attention_mask, labels=None):
        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            decoder_input_ids=decoder_input_ids,
                            decoder_attention_mask=decoder_attention_mask,
                            labels=labels)
        logits = output["logits"]
        if labels is not None:
            loss = output["loss"]
        else:
            loss = None
        return logits, loss

    def generator(self, decoder_input_ids, gen_max_seq_length, decoder_start_token_id,
                  num_beams=3, temperature=1., top_k=3, top_p=2, repetition_penalty=2.5,
                  no_repeat_ngram_size=5, encoder_no_repeat_ngram_size=7
                  ):
        output = self.model.generate(input_ids=decoder_input_ids,
                                     max_length=gen_max_seq_length,
                                     early_stopping=True,
                                     num_beams=num_beams,
                                     temperature=temperature,
                                     top_k=top_k,
                                     top_p=top_p,
                                     repetition_penalty=repetition_penalty,
                                     no_repeat_ngram_size=no_repeat_ngram_size,
                                     decoder_start_token_id=decoder_start_token_id,
                                     encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size
                                     )
        return output

模型效果如下:

{
    "text": "杜鹃放弃了繁华的故园山川,年复一年地四处飘荡。在异乡鸣叫,鲜血染红了山上花丛,可春天来到,老花园依然草木茂盛。雨后凉风,它藏在绿树丛中声声哀啼,夜幕初开,它迎着欲曙的天空肃然鸣叫。天色渐晚,它在湘江边凄凉鸣叫,使归家的船只行人悲愁之至。",
    "pred": "杜鹃弃故园,年复辗转漂。啼血染山花,春到老园草。雨后凉风藏,天开夜欲鸣。日暮湘江泣,归舟行人悲。"
  }

以上是介绍内容,如有错误,欢迎指证。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值