【Advanced】(九)、transformers实战文本摘要

1、介绍

 文本摘要任务的输入是长的文本文档,任务目标是将较长的文本转换成简短的摘要,一般来说生成简短的摘要必须要信息量充足,能够覆盖原文的主要内容。

  • 根据输入文档的数量划分,可以将摘要任务划分为单文档和多文档摘要

  • 根据输入和输出的语言划分,可以将摘要任务划分为单语言,跨语言,多语言摘要

评价指标:

rouge

  • rouge-1,rouge-2、rouge-l
  • 分别基于1-gram,2-gram和longest common subsequence

2、代码实战

2.1、导包

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

2.2、加载数据

ds = Dataset.load_from_disk('./nlpcc_2017/')
ds
ds = ds.train_test_split(100,seed=42,)
ds

2.3、数据处理

tokenizer = AutoTokenizer.from_pretrained('../Model/T5base')
tokenizer
def process_func(examples):
    contents = ['摘要生成:\n'+ e for e in examples['content']]
    inputs = tokenizer(contents,
                        max_length=128, 
                        truncation=True)
    labels = tokenizer(text_target=examples['title'],
                       max_length=32, 
                        truncation=True
                       ) 
    inputs['labels'] = labels['input_ids']
    return inputs
    pass
tokenizered_ds = ds.map(process_func, batched=True)
tokenizered_ds

2.4、创建模型

model = AutoModelForSeq2SeqLM.from_pretrained("../Model/T5base")

2.5、创建评估函数是

import numpy as np
from rouge_chinese import Rouge
rouge = Rouge()

def compute_metric(evalPred):
    predictions ,labels = evalPred
    decode_preds = tokenizer.batch_decode(predictions,skip_special_tokens=True)
    labels = np.where(labels !=-100,labels,tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels,skip_special_tokens=True)
    decode_preds = [' '.join(p) for p in decode_preds]
    decode_labels = [' '.join(l) for l in decode_labels]
    scores = rouge.get_scores(decode_preds,decode_labels,avg=True)
    return {
        'rouge-1':scores['rouge-1']['f'],
        'rouge-2':scores['rouge-2']['f'],
        'rouge-l':scores['rouge-l']['f'],

    }

2.6、配置训练参数

args = Seq2SeqTrainingArguments(
    output_dir="./summary",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=4,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="rouge-l",
    predict_with_generate=True#################
)

2.7、创建训练器

trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tokenizered_ds["train"],
    eval_dataset=tokenizered_ds["test"],
    compute_metrics=compute_metric,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)
)

2.8、训练

trainer.train()

2.9、模型推理

from transformers import pipeline
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
pipe("摘要生成:\n" + ds["test"][-1]["content"], max_length=64, do_sample=True)
  • 7
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鲸可落

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值