Pytorch框架下的transformers的使用

huggingface团队在pytorch框架下开发了transformers工具包:https://github.com/huggingface/transformers,工具包实现了大量基于transformer的模型,如albert,bert,roberta等。工具包的代码结构如图所示:

transformers工具包的包结构

其中比较重要的是src/transformers以及example这两个文件夹。其中,src/transformers文件夹下是各类transformer模型的实现代码;而examples下主要是各类下游任务的微调代码。我们以文本分类任务为例来说明微调过程具体是如何实现的,在官方的例子中,使用GLUE数据集。

一、run_glue.sh文件解析

按照官方文档的指引,首先需要构建用于启动微调程序的脚本文件,脚本为微调程序提供参数。

export GLUE_DIR=/path/to/glue
export TASK_NAME=MRPC

python ./examples/text-classification/run_glue.py \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_device_eval_batch_size=8   \
    --per_device_train_batch_size=8   \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir /tmp/$TASK_NAME/

其中几个主要参数的意义如下:

  • model_name_or_path:用于指定进行微调的预训练模型。参数可以是模型名称,在第一次执行微调程序时,会自动下载对应的模型;参数也可以是模型路径,此时需要提前下载对应的模型到设定的路径中。
  • task_name:用于指定具体的下游任务,微调程序需要根据任务名称选择相应的processor以实现数据加载。
  • data_dir:用于指定微调数据的存储路径。
  • output_dir:用于指定微调好的模型的存放路径

二、run_glue.py文件解析

启动脚本会调用run_glue.py文件来执行微调程序。程序主要有三部分功能:加载模型,加载数据,进行微调(训练,验证,预测)。

1、加载预训练模型

(1)加载用于构建模型以及用于微调过程的参数

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

其中,类ModelArguments中包含的是关于模型的属性,如model_name,config_name,tokenizer_name等,类在run.py文件中定义;类DataTrainingArguments中包含的是关于微调数据的属性,如task_name,data_dir等,类在transformers/data/datasets/glue.py文件中定义;TrainingArguments中包含的是关于微调过程的参数,如batch_size,learning_rate等参数,类在transformers/training_args.py中定义。

(2)生成model,config,tokenizer

其中,config用于加载配置信息,model根据config加载模型,tokenize用于在加载数据时提供编码信息。

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
)

2、加载数据

需要使用GlueDataset类构建数据,类定义在transformers/data/datasets/glue.py中,是对Dataset类的继承。

train_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
)
eval_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
    if training_args.do_eval
    else None
)
test_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
    if training_args.do_predict
    else None
)

在GlueDataset类中,需要利用glue_processors类来加载数据内容。glue_processors类定义在transformers/data/processors/glue.py中。

self.processor = glue_processors[args.task_name]()

if mode == Split.dev:
    examples = self.processor.get_dev_examples(args.data_dir)
elif mode == Split.test:
    examples = self.processor.get_test_examples(args.data_dir)
else:
    examples = self.processor.get_train_examples(args.data_dir)

3、微调(训练,验证,预测)

(1)构建训练器

训练器Trainer类:主要用于指定使用的模型,数据,微调过程所用参数的信息。类中包含用于训练,验证,预测的方法:trainer.train(train_dataset),trainer.evaluate(eval_dataset),trainer.predicate(test_dataset)。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=build_compute_metrics_fn(data_args.task_name),
)

(2)进行微调(训练,验证,预测)

三、如何定义自己的微调方法

有时候,我们的数据可能与官方所用的数据形式不同,这时候需要对方法进行重写以定义自己的微调方法,重写的内容主要包括:

  1. 重写dataset类
  2. 重写processor类

所有用到的参数都以属性的形式存在于ModelArguments,DataTrainingArguments,TrainingArguments这三个类中,若要改变某个参数,只需要在启动脚本中设置即可。

【1】项目代码完整且功能都验证ok,确保稳定可靠运行后才上传。欢迎下载使用!在使用过程中,如有问题或建议,请及时私信沟通,帮助解答。 【2】项目主要针对各个计算机相关专业,包括计科、信息安全、数据科学与大数据技术、人工智能、通信、物联网等领域的在校学生、专业教师或企业员工使用。 【3】项目具有较高的学习借鉴价值,不仅适用于小白学习入门进阶。也可作为毕设项目、课程设计、大作业、初期项目立项演示等。 【4】如果基础还行,或热爱钻研,可基于此项目进行二次开发,DIY其他不同功能,欢迎交流学习。 【注意】 项目下载解压后,项目名字和项目路径不要用中文,否则可能会出现解析不了的错误,建议解压重命名为英文名字后再运行!有问题私信沟通,祝顺利! 基于PyTorch框架构建Transformers模型并应用于机器翻译任务项目python源码+项目使用说明.zip 本项目是一个基于 PyTorch 框架构建 Transformers 模型并应用于翻译任务的项目,其中附带了详细的文档介绍 Transformers 模型在训练和推理中数据是如何变化的。 环境 * python 3.10 * torch 2.0.0 数据集 数据集采用的是 AI Challenger 的数据集,其中包含中英文个 1000 万条日常生活常用的短句。 下面是一些训练集的例子。 ```text A pair of red - crowned cranes have staked out their nesting territory A pair of crows had come to nest on our roof as if they had come for Lhamo. A couple of boys driving around in daddy's car. A pair of nines? You pushed in with a pair of nines? Fighting two against one is never ideal, ``` ```text 一对丹顶鹤正监视着它们的筑巢领地 一对乌鸦飞到我们屋顶上的巢里,它们好像专门为拉木而来的。 一对乖乖仔开着老爸的车子。 一对九?一对九你就全下注了? 一对二总不是好事, ``` 使用 * process_data.py 数据集预处理 * train.py 训练模型 * translate.py 使用模型进行英译中翻译 训练结果 由于硬件条件有限,只训练了1个epoch,但由于数据集很大,训练1个epoch就已经由一定的效果了,且loss还有下降的趋势。 下面是测试集的一些例子。 ```text Find a safety chain or something to keep these lights in place. So that no parent has to go through what I've known. I have to go to the date, learn to dance. Definitely. Now. Is when someone we've trusted makes the choice for us. Okay. Well, I guess there's not much to do about it right now then. I respect that, and I will protect it at all cost. ``` ```text 找到安全链或者保存这些灯。 所以我知道的不是父母。 我要去约会,学会跳舞。当然。现在。 是否有人信任我们。 好吧。那么,我想现在没什么可做的了。 我尊重这一点,我会保护它的。 ```
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值