【NLP】Text Generation and Text Error Correction: Code Study Notes

seq2seq-based-CGEC: grammatical error correction is framed as a translation task, translating an erroneous sentence into its corrected counterpart.

1. Parsing command-line arguments and options with the argparse module

1.1 Create a parser object

parser = argparse.ArgumentParser()

1.2 Add the required command-line arguments and options to the parser; each add_argument call registers one argument or option.

parser.add_argument()

1.3 Call parse_args() to parse the command line. A combined sketch of the three steps follows below.

parser.parse_args()
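
Putting the three steps together, a minimal argparse sketch (the argument names here are made up for illustration):

import argparse

parser = argparse.ArgumentParser()                                     # 1.1 create the parser
parser.add_argument("--model_name_or_path", type=str, required=True)   # 1.2 register an option
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()                                             # 1.3 parse the command line
print(args.model_name_or_path, args.seed)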

1.4 HfArgumentParser is the command-line parsing utility in the Transformers library. It is a subclass of ArgumentParser that builds a parser from dataclass objects. Here it is used to load the arguments for constructing and fine-tuning the model.

from transformers import HfArgumentParser

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args_list)

Here, ModelArguments holds the model-related attributes, DataTrainingArguments holds the attributes of the fine-tuning data, and Seq2SeqTrainingArguments holds the fine-tuning hyperparameters.
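
Each of the three classes is a Python dataclass whose fields HfArgumentParser turns into command-line flags. A minimal sketch of what such a dataclass might look like (the fields below are illustrative assumptions, not the actual definition in the repo):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    # illustrative fields only
    model_name_or_path: str = field(metadata={"help": "Path to the pretrained model or its name on the hub"})
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Where to cache downloaded models"})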

2. Setting the random seed

import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(training_args.seed)

3. Loading data

Build a data_files dict that can hold three files: train_file, validation_file and test_file. Each file is read with load_json and stored in datasets, so the training, validation and test sets can then be retrieved with the corresponding keys (train, validation, test).

import json
from datasets import Dataset

def load_json(file_path):
    # each file is a JSON list; every sample has an 'article' (source) and a 'summarization' (target)
    results = {'summarization': [], 'article': []}
    with open(file_path, encoding='utf-8') as f:
        content = json.load(f)
        for sample in content:
            results['summarization'].append(sample['summarization'])
            results['article'].append(sample['article'])
        results = Dataset.from_dict(results)
    return results

datasets = {}
data_files = {}
if data_args.train_file is not None:
    data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
    data_files["validation"] = data_args.validation_file
if data_args.test_file is not None:
    data_files["test"] = data_args.test_file
for key in data_files:
    datasets[key] = load_json(data_files[key])
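
From load_json above, each data file is expected to be a JSON list in which every item carries an 'article' field (the source sentence) and a 'summarization' field (the target sentence). The snippet below only illustrates the expected structure and how the resulting Dataset is accessed:

# expected file layout (values elided):
# [
#   {"article": "...", "summarization": "..."},
#   ...
# ]
train_dataset = datasets["train"]            # a datasets.Dataset object
print(train_dataset[0]["article"])           # source text of the first sample
print(train_dataset[0]["summarization"])     # target text of the first sample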

4. Data preprocessing

4.1 In the preprocessing stage, the tokenizer is instantiated with AutoTokenizer.from_pretrained(). The tokenizer can preprocess either a single sentence or a sentence pair, and its output matches the input format expected by the pretrained model.
Aside: the difference between **AutoTokenizer.from_pretrained(pretrained_model_path)** and **BertTokenizer.from_pretrained(pretrained_model_path)** is that AutoTokenizer is a generic wrapper that picks the tokenizer matching the loaded pretrained model, whereas BertTokenizer specifically loads the WordPiece-based BERT tokenizer.

from transformers import AutoTokenizer, BertTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Note: the tokenizer does not convert tokens into word vectors; it only segments the raw text, adds special tokens such as [MASK], [SEP] and [CLS], and converts the tokens into vocabulary indices.
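
A quick sketch of what the tokenizer returns for a single sentence, assuming tokenizer was loaded as above from a BERT-style Chinese checkpoint (the exact keys depend on the model):

encoded = tokenizer("今天天气很好", max_length=16, padding="max_length", truncation=True)
print(encoded["input_ids"])       # vocabulary indices, with [CLS]/[SEP] added automatically
print(encoded["attention_mask"])  # 1 for real tokens, 0 for padding positions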

4.2 Define a preprocessing function preprocess_function that converts the raw training, validation and test sets into the input format the model accepts. Taking text correction as the example, the two key fields are text_column and summary_column, which correspond to the erroneous sentence and the correct sentence respectively.

def preprocess_function(examples):
    inputs = examples[text_column]        # source (erroneous) sentences
    targets = examples[summary_column]    # target (corrected) sentences
    model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
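
preprocess_function relies on two variables defined elsewhere in the script, padding and max_target_length. A plausible setup, modeled on the standard Transformers summarization example (the field names here are my assumption, not necessarily identical to this repo):

# assumption: DataTrainingArguments exposes pad_to_max_length and max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
max_target_length = data_args.max_target_length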

4.3 Then preprocess every sample in datasets with the map function.

train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)

eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)

test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)

4.4 Construct a data collator with DataCollatorForSeq2Seq.

from transformers import DataCollatorForSeq2Seq

label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8 if training_args.fp16 else None,
)
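
The collator dynamically pads input_ids and labels to the longest sequence in each batch; padded label positions get label_pad_token_id=-100, so the loss ignores them. A small usage sketch (the batch size is arbitrary):

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=4, collate_fn=data_collator)
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["labels"].shape)  # both padded to the longest sample in the batch

In practice the Seq2SeqTrainer builds this DataLoader internally; the snippet only shows what the collator produces.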

5. Defining the evaluation function

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    # ROUGE cannot score an empty hypothesis, so replace empty predictions with a period
    while '' in preds:
        idx = preds.index('')
        preds[idx] = '。'
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if data_args.ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    scores = rouge.get_scores(decoded_preds, decoded_labels, avg=True)

    # keep only the F1 value of each ROUGE metric, scaled to a percentage
    for key in scores:
        scores[key] = scores[key]['f'] * 100
    result = scores

    # average generated length: number of non-padding tokens per prediction
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
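
rouge here is presumably an instance of Rouge from the rouge package; its get_scores(..., avg=True) returns averaged precision/recall/F1 per metric, which is why only the 'f' entry is kept above. A sketch of the expected shape (my assumption about the implementation in use):

from rouge import Rouge

rouge = Rouge()
scores = rouge.get_scores(["the cat sat"], ["the cat sat down"], avg=True)
# scores looks like: {'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}
print(scores["rouge-1"]["f"])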

6. Model training and prediction

6.1 Initialize a trainer with Seq2SeqTrainer

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

6.2 Model training

if training_args.do_train:
    train_result = trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
    )
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
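
Since the trainer was also given eval_dataset when do_eval is set, the validation pass would look roughly like this (a sketch following the usual Trainer API, not copied from the original script):

if training_args.do_eval:
    metrics = trainer.evaluate(metric_key_prefix="eval")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)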

6.3 Model prediction on the test set

if training_args.do_predict:
    if training_args.predict_with_generate:
        predictions, labels, metrics = trainer.predict(test_dataset, metric_key_prefix="predict")

        test_preds = tokenizer.batch_decode(
            predictions, skip_special_tokens=True,
        )
        # remove the whitespace inserted between Chinese tokens during decoding
        test_preds = ["".join(pred.strip().split()) for pred in test_preds]

        output_test_preds_file = args.predict_file
        with open(output_test_preds_file, "w", encoding='UTF-8') as writer:
            writer.write("\n".join(test_preds))

[Note: this is my own preliminary analysis of the code, recorded to help me organize my understanding. If anything above is wrong, corrections are very welcome! Let's keep at it together~]
