seq2seq-based-CGEC: grammatical error correction is treated as translating an erroneous sentence into its corrected counterpart.
1. Parse command-line arguments and options with the argparse module
1.1 Create a parser object
import argparse
parser = argparse.ArgumentParser()
1.2 Add the required command-line arguments and options to the parser; each add_argument call registers one argument or option.
parser.add_argument("--seed", type=int, default=42)  # e.g. register a --seed option (illustrative flag)
1.3 Call parse_args() to parse the command line.
parser.parse_args()
1.4 HfArgumentParser is the command-line parsing utility of the Transformers library. It is a subclass of ArgumentParser and builds a parser from dataclass definitions. Here it loads the arguments used to build and fine-tune the model.
from transformers import HfArgumentParser
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
# args_list is an optional list of argument strings; if omitted, sys.argv is parsed instead
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args_list)
ModelArguments holds the model-related attributes, DataTrainingArguments holds the attributes of the fine-tuning data, and Seq2SeqTrainingArguments holds the fine-tuning hyperparameters.
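For reference, ModelArguments and DataTrainingArguments are ordinary Python dataclasses whose fields HfArgumentParser turns into command-line flags; a minimal sketch with illustrative fields only (the original definitions are not shown in this post):
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    # path or hub id of the pretrained model and tokenizer
    model_name_or_path: str = field(metadata={"help": "pretrained model path"})

@dataclass
class DataTrainingArguments:
    train_file: Optional[str] = field(default=None)
    validation_file: Optional[str] = field(default=None)
    test_file: Optional[str] = field(default=None)
    max_source_length: int = field(default=128)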
2. Set the random seed
import random
import numpy as np
import torch

def set_seed(seed):
    # seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(training_args.seed)
3. Load the data
Build a data_files dict that holds the three files train_file, validation_file and test_file; load_json reads each entry of data_files into datasets, and the training, validation and test sets are then accessed through the corresponding keys (train, validation, test).
import json
from datasets import Dataset

def load_json(file_path):
    # read one JSON file and wrap it as a datasets.Dataset with two text columns
    results = {'summarization': [], 'article': []}
    with open(file_path, encoding='utf-8') as f:
        content = json.load(f)
    for sample in content:
        results['summarization'].append(sample['summarization'])
        results['article'].append(sample['article'])
    return Dataset.from_dict(results)

datasets = {}
data_files = {}
if data_args.train_file is not None:
    data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
    data_files["validation"] = data_args.validation_file
if data_args.test_file is not None:
    data_files["test"] = data_args.test_file
for key in data_files:
    datasets[key] = load_json(data_files[key])
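load_json assumes every file is a JSON array of objects with summarization and article fields; judging from how text_column and summary_column are used in step 4.2, article holds the erroneous sentence and summarization the corrected one, so a file would look roughly like:
[
  {"article": "erroneous sentence 1", "summarization": "corrected sentence 1"},
  {"article": "erroneous sentence 2", "summarization": "corrected sentence 2"}
]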
4. Data preprocessing
4.1 In the preprocessing stage, we instantiate the tokenizer with AutoTokenizer.from_pretrained(). The tokenizer can preprocess a single sentence as well as a sentence pair, and its output already matches the input format the pretrained model expects.
Aside: the difference between **AutoTokenizer.from_pretrained(pretrained_model_path)** and **BertTokenizer.from_pretrained(pretrained_model_path)** is that AutoTokenizer is a generic wrapper that picks the tokenizer matching the loaded pretrained model, whereas BertTokenizer always loads the WordPiece-based BERT tokenizer.
from transformers import AutoTokenizer, BertTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
Note: the tokenizer does not turn tokens into word vectors; it only tokenizes the raw text, adds special tokens such as [CLS] and [SEP], and converts the tokens into vocabulary indices.
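A quick sketch of what the tokenizer returns for one sentence (ids shown for a BERT-style vocabulary, where 101/102 are [CLS]/[SEP]; the exact values depend on the checkpoint):
encoded = tokenizer("这句话有语法错误", max_length=32, truncation=True)
# encoded is a plain dict of lists, roughly:
# {'input_ids': [101, ..., 102], 'token_type_ids': [0, ...], 'attention_mask': [1, ...]}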
4.2 Define the preprocessing function preprocess_function, which converts the raw training, validation and test sets into the input format the model accepts. For text correction, two columns matter: text_column and summary_column, holding the erroneous sentence and the corrected sentence respectively.
def preprocess_function(examples):
    inputs = examples[text_column]      # erroneous source sentences
    targets = examples[summary_column]  # corrected target sentences
    model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
    # tokenize the targets with the target-side tokenizer settings
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
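preprocess_function closes over text_column, summary_column, padding and max_target_length, none of which appear in the snippet; a minimal sketch of plausible definitions (the data_args field names here are assumptions modeled on the usual seq2seq fine-tuning setup):
# dataset columns produced by load_json
text_column = "article"          # erroneous source sentences
summary_column = "summarization" # corrected target sentences

# pad everything to max_length up front, or let the data collator pad per batch
padding = "max_length" if data_args.pad_to_max_length else False
max_target_length = data_args.max_target_length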
4.3 Then preprocess every sample in datasets with map.
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)
eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)
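These map calls assume that train_dataset, eval_dataset, test_dataset and column_names already exist; a minimal sketch of that setup, reusing the datasets dict built in step 3:
train_dataset = datasets["train"]
eval_dataset = datasets["validation"]
test_dataset = datasets["test"]
# the raw text columns ('summarization', 'article') are dropped after tokenization
column_names = train_dataset.column_names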
4.4 Build a data collator with DataCollatorForSeq2Seq
from transformers import DataCollatorForSeq2Seq

# -100 is ignored by the loss, so padded label positions do not contribute to it
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8 if training_args.fp16 else None,
)
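Both the data collator above and the trainer below reference a model object that the snippets never construct; a minimal sketch, assuming a generic encoder-decoder checkpoint loaded with AutoModelForSeq2SeqLM and a model_name_or_path field on ModelArguments (the original script's exact loading code is not shown):
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(model_args.model_name_or_path)
# keep the embedding matrix in sync in case the tokenizer added extra tokens
model.resize_token_embeddings(len(tokenizer))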
5. Define the evaluation function
import numpy as np

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    # rouge cannot score empty hypotheses, so replace them with a placeholder
    while '' in preds:
        idx = preds.index('')
        preds[idx] = '。'
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if data_args.ignore_pad_token_for_loss:
        # replace -100 in the labels, since it can't be decoded
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # simple post-processing before scoring
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    scores = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    for key in scores:
        scores[key] = scores[key]['f'] * 100  # keep only the f-measure, as a percentage
    result = scores
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
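The rouge object used in compute_metrics is not created in the snippet; it behaves like an instance of Rouge from the third-party rouge package (an assumption on my part), whose get_scores(..., avg=True) returns per-metric r/p/f dicts:
from rouge import Rouge

rouge = Rouge()
example = rouge.get_scores(["this is a prediction"], ["this is a reference"], avg=True)
# example looks like:
# {'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}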
6. Model training and prediction
6.1 Initialize a trainer with Seq2SeqTrainer
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
6.2 Model training
if training_args.do_train:
    train_result = trainer.train()
    trainer.save_model()  # also saves the tokenizer for easy upload
    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
    )
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
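The trainer is given eval_dataset when do_eval is set; the validation pass itself is not shown in these notes, but it usually looks like the following sketch:
if training_args.do_eval:
    metrics = trainer.evaluate(metric_key_prefix="eval")
    metrics["eval_samples"] = len(eval_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)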
6.3 Model testing
if training_args.do_predict:
    if training_args.predict_with_generate:
        predictions, labels, metrics = trainer.predict(test_dataset, metric_key_prefix="predict")
        test_preds = tokenizer.batch_decode(
            predictions, skip_special_tokens=True,
        )
        # strip the whitespace the tokenizer inserts between Chinese characters
        test_preds = ["".join(pred.strip().split()) for pred in test_preds]
        output_test_preds_file = args.predict_file  # output path taken from the script's own CLI arguments
        with open(output_test_preds_file, "w", encoding='UTF-8') as writer:
            writer.write("\n".join(test_preds))
[Note: I am still working through this code myself; these notes are just a record to help me organize my own understanding. If anything above is wrong, corrections are very welcome! Let's keep at it together~]