TRL - Transformer Reinforcement Learning(基于Transformer的强化学习)

TRL - Transformer Reinforcement Learning(基于Transformer的强化学习)

flyfish

地址:https://github.com/huggingface/trl

TRL是一个用于对基础模型进行后训练的全面库,是一个前沿的库,专门用于使用先进的技术(如监督微调(SFT)、近端策略优化(PPO)和直接偏好优化(DPO))对基础模型进行后训练。

利用Accelerate,通过DDP和DeepSpeed等方法,从单个GPU扩展到多节点集群。与PEFT完全集成,通过量化和LoRA/QLoRA,即使在硬件资源有限的情况下也能训练大型模型。集成Unsloth,通过优化的内核加速训练。

通过SFTTrainer、DPOTrainer、RewardTrainer、ORPOTrainer等训练器,轻松访问各种微调方法。使用预定义的模型类(如AutoModelForCausalLMWithValueHead)简化与LLMs的强化学习(RL)

如何使用 TRL(Transformers Reinforcement Learning)库

1. SFTTrainer - 监督微调训练器

  • 用途:用于在自定义数据集上进行监督微调(Supervised Fine-Tuning, SFT)。
  • 使用方法
    1. 加载你的训练数据集,例如 "trl-lib/Capybara"
    2. 配置训练参数,包括输出目录等。
    3. 初始化 SFTTrainer 并传入模型、训练数据集等参数。
    4. 调用 .train() 方法开始训练。
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()

2. RewardTrainer - 奖励模型训练器

  • 用途:用于训练奖励模型,该模型可以评估文本生成的质量,通常用于强化学习框架中。
  • 使用方法
    1. 加载预训练的模型和分词器。
    2. 加载适合于奖励模型的数据集。
    3. 配置训练参数。
    4. 初始化 RewardTrainer 并传入必要组件。
    5. 开始训练。
from trl import RewardConfig, RewardTrainer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()

3. GRPOTrainer - 组相对策略优化训练器

  • 用途:基于组相对策略优化算法,适用于更高效的内存利用场景。
  • 使用方法:与上述类似,但需要定义一个奖励函数。
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

4. DPOTrainer - 直接偏好优化训练器

  • 用途:用于根据人类偏好直接优化语言模型。
  • 使用方法:加载模型和分词器,配置训练参数,并初始化 DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()

安装

pip install trl

根据自己的需求定制

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]

-e 参数:表示“editable”模式安装。这意味着可以在开发过程中直接从源代码目录运行程序,而不需要每次修改后重新安装。这对于开发阶段非常有用,因为它允许在原地编辑代码,并立即看到更改的效果。

. (点):这表示当前目录是需要被安装的Python包或项目。也就是说,pip会在这个目录下寻找setup.py文件(或pyproject.toml等),并根据里面定义的信息进行安装。

[dev]:这部分指定了一个额外的依赖集,通常在setup.py文件中的extras_require字段中定义。[dev]意味着除了基本依赖外,还将安装标记为开发所需的额外依赖项

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值