Hugging Face Trainer API 进阶指南:自定义训练循环全解析!


🔍 让我们一起探索 Hugging Face 的 Trainer API 自定义训练循环!

📌 1. 预备工作

在开始之前,先安装必要的 Python 库:

pip install transformers datasets

此外,你还需要安装适用于你环境的 PyTorch 版本,请确保安装正确的版本。


🛠 2. 使用 Trainer API 进行自定义训练

如果你曾经微调过 Transformer 模型,你可能会好奇其底层机制,以及如何自定义调整训练流程

Trainer API 允许你直接使用 Hugging Face 提供的默认训练流程,但如果你的需求较为特殊,你可以自定义训练循环以满足特定任务需求。

📌 2.1 加载预训练模型、分词器和数据集

在本示例中,我们将使用 BERT 进行文本分类任务,并加载 IMDb 数据集。

from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

dataset = load_dataset('imdb')

📌 2.2 数据预处理

我们对文本进行分词处理,并只选取部分数据以加快训练过程。

def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True)

tokenized_datasets = dataset.map(preprocess_function, batched=True)

small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(100))
small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(50))

📌 2.3 训练参数设置

我们将训练 1 个 epoch,并设置较大的 batch size 以加快训练。

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    logging_dir='./logs',
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    logging_steps=10,
    save_total_limit=2,
)

🔥 3. 自定义 Trainer 训练循环

📌 3.1 继承 Trainer 类,创建自定义训练器

在 Hugging Face 默认的 Trainer 基础上,我们可以自定义优化器、调度器和训练逻辑

from torch.optim import AdamW
from transformers import get_scheduler
from transformers import Trainer

class CustomTrainer(Trainer):
    def create_optimizer_and_scheduler(self, num_training_steps):
        if self.optimizer is None:
            self.optimizer = AdamW(self.model.parameters(), lr=self.args.learning_rate)
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                name="linear",
                optimizer=self.optimizer,
                num_warmup_steps=0,
                num_training_steps=num_training_steps,
            )

    def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs):
        # 初始化优化器和调度器
        num_training_steps = len(self.get_train_dataloader()) * self.args.num_train_epochs
        self.create_optimizer_and_scheduler(num_training_steps)

        model = self.model
        for epoch in range(int(self.args.num_train_epochs)):
            print(f"Starting epoch {epoch + 1}")
      
            for step, batch in enumerate(self.get_train_dataloader()):   
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
               
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                if step % self.args.logging_steps == 0:
                    print(f"Step {step}: Loss = {loss.item()}")

        print("Training is Done")

📌 3.2 代码解析(我们自定义了哪些内容?)

使用 AdamW 作为优化器:自适应学习率的优化器,适用于 Transformer 训练。
设置线性学习率调度器:随着训练进行,逐步降低学习率,提高模型收敛效果。
自定义训练循环:每个 batch 计算损失、反向传播、更新梯度,并每隔 logging_steps 进行日志记录。


🚀 4. 训练和评估模型

📌 4.1 实例化自定义 Trainer 并开始训练

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

trainer.train()

📌 4.2 评估模型

evaluation_results = trainer.evaluate()
print(evaluation_results)

📌 4.3 训练结果示例

{
    'eval_loss': 0.1545, 
    'eval_model_preparation_time': 0.0038, 
    'eval_runtime': 765.59, 
    'eval_samples_per_second': 32.65, 
    'eval_steps_per_second': 1.02
}

🎯 结论

通过自定义 Trainer,你可以灵活调整训练逻辑,满足特定任务需求。例如:

修改优化器(如 AdamW、SGD)
更改学习率调度器(如线性衰减、余弦退火)
调整训练日志输出方式
添加更多自定义训练逻辑(如梯度累积、混合精度训练等)

如果你的 Transformer 训练需求超出了 Hugging Face 默认 Trainer 的功能,那么掌握 Trainer API 的自定义训练循环将极大提升你的开发效率!🚀


💡 你对 Hugging Face 的 Trainer API 还有哪些疑问?欢迎留言讨论!👇

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值