【Hugging Face】transformers 库中的 TrainingArguments: 训练参数管理类,控制 Trainer 的训练行为

Hugging Face transformers 库中的 TrainingArguments

TrainingArguments 是 Hugging Face transformers 库中的一个 训练参数管理类,用于 控制 Trainer 的训练行为,如 批量大小、学习率、日志记录、保存策略、多 GPU 训练、混合精度 等。


1. 为什么使用 TrainingArguments

在 PyTorch 训练中,我们通常需要手动设置训练参数,如:

learning_rate = 3e-5
batch_size = 16
num_epochs = 3
save_steps = 500

TrainingArguments 提供了 统一的接口,让用户可以方便地管理所有训练参数,并直接传递给 Trainer,例如:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    save_steps=500,
    evaluation_strategy="epoch",
)

这样可以 避免手动管理参数,提高代码可读性


2. TrainingArguments 的基本用法

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",  # 训练结果保存路径
    evaluation_strategy="epoch",  # 评估策略
    save_strategy="epoch",  # 保存策略
    per_device_train_batch_size=8,  # 训练批量大小
    per_device_eval_batch_size=8,  # 评估批量大小
    num_train_epochs=3,  # 训练轮数
    logging_dir="./logs",  # 日志保存路径
    logging_steps=10,  # 多少步记录一次日志
    save_total_limit=2,  # 最多保留多少个模型检查点
)

然后,将 training_args 传入 Trainer

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

3. TrainingArguments 关键参数

3.1. 训练参数

参数作用默认值
output_dir训练结果保存路径
num_train_epochs训练轮数3
per_device_train_batch_size每个 GPU/CPU 的训练批量大小8
per_device_eval_batch_size每个 GPU/CPU 的评估批量大小8
learning_rate初始学习率5e-5
weight_decay权重衰减0.0
warmup_steps预热步数0
adam_beta1Adam 优化器 β1 参数0.9
adam_beta2Adam 优化器 β2 参数0.999

示例:

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=500,
)

3.2. 评估与日志

参数作用默认值
evaluation_strategy评估策略 ("no", "epoch", "steps")"no"
eval_steps多少步评估一次500
logging_dir日志路径"runs/"
logging_steps多少步记录一次日志500
save_strategy模型保存策略 ("no", "epoch", "steps")"steps"
save_steps多少步保存一次模型500
save_total_limit最多保存多少个模型无限制

示例:

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
)

3.3. 多 GPU 和混合精度

参数作用默认值
fp16是否使用混合精度(减少显存占用)False
fp16_full_eval评估时是否使用混合精度False
gradient_accumulation_steps梯度累积步数1
ddp_find_unused_parameters适用于 torch.nn.DataParallelNone

示例:

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # 梯度累积,等效于 batch_size=16*4
    fp16=True,  # 启用混合精度
)

3.4. 训练控制

参数作用默认值
disable_tqdm是否禁用进度条False
max_steps最大训练步数-1(不限制)
report_to记录日志的工具 ("wandb", "tensorboard", "none")"none"

示例:

training_args = TrainingArguments(
    output_dir="./results",
    max_steps=10000,  # 最多训练 10000 步
    disable_tqdm=True,  # 禁用进度条
    report_to="tensorboard",  # 记录日志到 TensorBoard
)

4. TrainingArgumentsTrainer 结合

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

5. 多 GPU 训练

如果有多个 GPU,可以运行:

accelerate launch train.py

或者在 TrainingArguments 里启用:

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    fp16=True,  # 启用混合精度
)

6. TrainingArguments vs PyTorch 训练参数

功能TrainingArgumentsPyTorch
设置学习率learning_rate=3e-5torch.optim.Adam(model.parameters(), lr=3e-5)
训练批量大小per_device_train_batch_size=8DataLoader(dataset, batch_size=8)
训练轮数num_train_epochs=3for epoch in range(3):
评估策略evaluation_strategy="epoch"需手动实现
混合精度训练fp16=True需手动实现

使用 TrainingArguments 可以 大大简化 PyTorch 训练代码


7. TrainingArguments 适用于哪些任务?

任务适用情况
文本分类
机器翻译
文本摘要
语音识别
计算机视觉适用于 ViT、CLIP 等

8. 总结

  1. TrainingArguments 用于管理 Trainer 的训练参数,让 Hugging Face 训练更简单。
  2. 支持批量大小、学习率、日志记录、评估策略、多 GPU 训练
  3. 减少手动管理超参数的工作量,可以通过 Trainer 直接调用。
  4. 适用于 NLP、CV、语音等任务,可以高效微调 Transformer 模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值