import pandas as pd
import datasets
import jieba
import numpy as np
import lawrouge
import torch
from datasets import load_dataset, Dataset
from transformers import BertTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from modeling_cpt import CPTForConditionalGeneration
from transformers.utils import logging
# Preprocess the LCSTS data
max_input_length = 512
max_target_length = 128
# on_bad_lines replaces the warn_bad_lines/error_bad_lines pair removed in pandas 2.0;
# the regex separator requires the python parser engine.
lcsts_part_1 = pd.read_table('./SourceDataset/PART_II.txt', header=None,
                             on_bad_lines='warn', engine='python',
                             sep='<[/d|/s|do|su|sh][^a].*>', encoding='utf-8')
lcsts_part_1 = lcsts_part_1[0].dropna()
lcsts_part_1 = lcsts_part_1.reset_index(drop=True)
# After splitting on the XML-like tags, rows alternate summary / short_text,
# so odd rows are the documents and even rows the summaries.
lcsts_part_1 = pd.concat([lcsts_part_1[1::2].reset_index(drop=True),
                          lcsts_part_1[::2].reset_index(drop=True)], axis=1)
lcsts_part_1.columns = ['document', 'summary']
lcsts_part_2 = pd.read_table('./SourceDataset/PART_III.txt', header=None,
                             on_bad_lines='warn', engine='python',
                             sep='<[/d|/s|do|su|sh][^a].*>', encoding='utf-8')
lcsts_part_2 = lcsts_part_2[0].dropna()
lcsts_part_2 = lcsts_part_2.reset_index(drop=True)
lcsts_part_2 = pd.concat([lcsts_part_2[1::2].reset_index(drop=True),
                          lcsts_part_2[::2].reset_index(drop=True)], axis=1)
lcsts_part_2.columns = ['document', 'summary']
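# Optional sanity check (a quick sketch): after the tag split both frames
# should hold (document, summary) pairs, assuming the LCSTS dump interleaves
# summary and short_text lines as handled above.
print(lcsts_part_1.shape, lcsts_part_2.shape)
print(lcsts_part_1.head(2))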
dataset_train = Dataset.from_pandas(lcsts_part_1)
dataset_valid = Dataset.from_pandas(lcsts_part_2)
TokenModel = "bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(TokenModel)
def preprocess_function(examples):
    inputs = [str(doc) for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    targets = [str(doc) for doc in examples["summary"]]
    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
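# Quick illustration (optional; the toy strings are made up): the preprocessor
# returns input_ids / token_type_ids / attention_mask for the documents plus
# token-id labels for the summaries.
_toy = preprocess_function({"document": ["今天天气很好"], "summary": ["天气好"]})
print(list(_toy.keys()))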
tokenized_datasets_t = dataset_train.map(preprocess_function, batched=True)
tokenized_datasets_v = dataset_valid.map(preprocess_function, batched=True)
tokenized_datasets = datasets.DatasetDict({"train":tokenized_datasets_t,"validation": tokenized_datasets_v})
model_checkpoint = "fnlp/cpt-large"
model = CPTForConditionalGeneration.from_pretrained(model_checkpoint)
print(model)
logger = logging.get_logger(__name__)
batch_size = 1
args = Seq2SeqTrainingArguments(
    output_dir="results-CPT",
    num_train_epochs=2,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,  # demo
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-04,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=500,
    save_total_limit=3,
    # generation_max_length is the maximum generated length (library default: 20);
    # generation_num_beams=1 means greedy decoding, >1 enables beam search.
    generation_max_length=64,
    generation_num_beams=1,
)
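# Note: with the arguments above no evaluation runs during training (no
# evaluation strategy is set); metrics only come from the explicit
# trainer.evaluate() call at the end. Passing evaluation_strategy="epoch"
# (eval_strategy in newer transformers) would evaluate once per epoch.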
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# Compute ROUGE scores for Chinese summaries
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    length = len(decoded_labels)
    rouge = lawrouge.Rouge()
    result = rouge.get_scores(decoded_preds, decoded_labels)
    # Average the per-pair F1 scores over the batch.
    rouge_1, rouge_2, rouge_l = 0, 0, 0
    for r in result:
        rouge_1 += r['rouge-1']['f']
        rouge_2 += r['rouge-2']['f']
        rouge_l += r['rouge-l']['f']
    result = {'rouge-1': rouge_1 / length, 'rouge-2': rouge_2 / length, 'rouge-l': rouge_l / length}
    result = {key: value * 100 for key, value in result.items()}
    # Add mean generated length
    # prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    # result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 2) for k, v in result.items()}
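# Minimal smoke test (optional sketch): score identical prediction/label token
# ids; all three ROUGE F1 values should come out at ~100, assuming lawrouge's
# get_scores handles single-element lists as used above.
_ids = np.array([tokenizer("天气很好")["input_ids"]])
print(compute_metrics((_ids, _ids)))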
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# Train
train_result = trainer.train()
# Training and testing really ought to live in separate scripts; this is a rough draft.
# Quick test
torch.save(model.state_dict(), "CPT-Chinese.pth")
model.load_state_dict(torch.load('./CPT-Chinese.pth'))
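# The save/load round trip above just verifies the checkpoint file. An
# alternative (a sketch; "results-CPT/final" is an arbitrary directory) is
# the standard HF format, which also stores the config and tokenizer:
# trainer.save_model("results-CPT/final")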
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples,
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=128)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str
_,x = generate_summary("20日凌晨,寒风刺骨,两名年纪相仿的婴儿相继被狠心的父母遗弃在翔安的两个角落,一个在莲花总院厕所里,一个在东园社区一榕树下。两名婴儿被发现时间相距不过10分钟,莲河边防派出所民警连夜走访,未寻得婴儿家属。目前,一名婴儿已被送往福利院,另一名暂时安置在村民家中。据悉,经医生初步检查,两名婴儿均身体健康,无残疾、无疾病。记者陈佩珊通讯员蔡美娟林才龙",model)
print(x)
print(len(x[0]))
'''
Sample output (the results don't look great; BART is probably the better choice):
2 名 婴 儿 寒 风 刺 骨 被 发 现 时 间 相 距 不 过 10 分 钟, 警 察 连 夜 走 访 未 找 到 婴 儿 家 属 ; 婴 儿 已 被 送 福 利 院 。
'''
# Evaluate
eval_results = trainer.evaluate()
print(eval_results)
# Just a rough sketch; there are bound to be plenty of problems. I'll fix them as they come up.