Introduction to Direct Preference Optimization (DPO): Process and Implementation
Direct Preference Optimization (DPO) is a preference-based alignment method that trains language models to better match user preferences, without the explicit reinforcement learning loop used in RLHF. This blog provides a detailed explanation of DPO’s core concepts, its step-by-step pipeline, and a practical implementation with code examples.
1. What is DPO?
DPO is an optimization strategy that directly learns from human preference data by minimizing a preference-based loss function, eliminating the need for complex reinforcement learning frameworks like Proximal Policy Optimization (PPO).
Key Ideas of DPO:
- Utilize human preference datasets containing pairs of “preferred” and “less preferred” completions.
- Optimize a preference-based loss function to guide the policy model while controlling KL divergence to prevent distributional shifts.
Advantages of DPO:
- No Environment Interaction Required: DPO trains offline, reducing complexity compared to traditional reinforcement learning.
- Direct Use of Human Preferences: Simplifies optimization by directly leveraging preference data.
- Stable Training Process: KL divergence constraints ensure smooth updates and prevent instability.
2. Main Workflow of DPO
Step 1: Prepare the Dataset
Assume we have a human-annotated preference dataset $D$:
$$D = \{(x_i,\, y_w^i,\, y_l^i)\}_{i=1}^N$$
Where:
- $x_i$: Prompt or input example.
- $y_w^i$: Preferred output (completion).
- $y_l^i$: Less preferred output.
Sample Data:
Prompt | Preferred Output ($y_w$) | Less Preferred Output ($y_l$) |
---|---|---|
“Summarize this text.” | “Key points are…” | “The main topic is…” |
“Translate to French.” | “Bonjour tout le monde.” | “Salut à tous.” |
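In code, each record of $D$ can be stored as a simple prompt/chosen/rejected triple. The sketch below mirrors the sample rows above; the field names are illustrative conventions, not anything prescribed by DPO itself.

```python
# A minimal in-memory representation of the preference dataset D
# (field names are illustrative; any equivalent structure works).
preference_data = [
    {
        "prompt": "Summarize this text.",
        "chosen": "Key points are…",       # y_w: preferred completion
        "rejected": "The main topic is…",  # y_l: less preferred completion
    },
    {
        "prompt": "Translate to French.",
        "chosen": "Bonjour tout le monde.",
        "rejected": "Salut à tous.",
    },
]
```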
Step 2: Initialize Reference Model ($\pi_{ref}$)
- If a supervised fine-tuned model ($\pi_{SFT}$) is available, use it as the reference model:
$$\pi_{ref} = \pi_{SFT}$$
- If $\pi_{SFT}$ is not available, pre-train a reference model via Maximum Likelihood Estimation (MLE):
$$\pi_{ref} = \arg\max_{\pi} \, \mathbb{E}_{(x, y_w) \sim D} \left[\log \pi(y_w \mid x)\right]$$
Explanation:
The reference model provides a baseline distribution; keeping the policy close to this baseline limits distributional shift during training and stabilizes optimization.
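As a concrete illustration of this step, here is a minimal initialization sketch, assuming the Hugging Face Transformers library and a placeholder SFT checkpoint path; the variable names are just for this example.

```python
import copy
from transformers import AutoModelForCausalLM

# Load the SFT model as the trainable policy (placeholder checkpoint path).
policy_model = AutoModelForCausalLM.from_pretrained("path/to/sft-checkpoint")

# The reference model is a frozen copy: it only supplies baseline log-probabilities.
ref_model = copy.deepcopy(policy_model)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)
```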
Step 3: Define and Optimize the Loss Function
DPO’s loss function is designed to optimize the policy model $\pi_\theta$ relative to the reference model $\pi_{ref}$ while introducing a KL-divergence penalty to constrain updates.
Loss Function Formula:
$$L_{DPO}(\theta) = -\log \sigma\left(\beta\left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_\theta(y_l \mid x)} - \log \frac{\pi_{ref}(y_w \mid x)}{\pi_{ref}(y_l \mid x)}\right]\right)$$
Where:
- $\beta$: Temperature hyperparameter controlling the strength of the KL penalty.
- $\sigma(\cdot)$: The sigmoid function, which maps the scaled difference of log-ratios to a preference probability.
- $\pi_\theta$: Output probabilities of the current policy model.
- $\pi_{ref}$: Output probabilities of the reference model.
Intuitive Explanation:
- Compare the probability ratios between preferred and less preferred outputs.
- Optimize the model to increase probabilities for preferred outputs while limiting deviation from the reference model.
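A useful way to read this (taken from the DPO paper, not an extra assumption): the quantity $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{ref}(y \mid x)}$ acts as DPO's implicit reward, and the loss above is exactly the Bradley–Terry negative log-likelihood that $y_w$ is preferred over $y_l$ under this reward:
$$L_{DPO}(\theta) = -\log \sigma\left(\hat{r}_\theta(x, y_w) - \hat{r}_\theta(x, y_l)\right), \qquad \hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{ref}(y \mid x)}.$$
This is also why the reference implementation below returns `beta * (pi_logps - ref_logps)` as its reward signal.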
Code Implementation (excerpted from the original paper: https://arxiv.org/pdf/2305.18290):
```python
import torch.nn.functional as F

def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta):
    """
    Compute the DPO loss.

    Args:
        pi_logps: Log-probabilities from the policy model, shape (B,)
        ref_logps: Log-probabilities from the reference model, shape (B,)
        yw_idxs: Indices of preferred completions, shape (T,)
        yl_idxs: Indices of less preferred completions, shape (T,)
        beta: KL penalty strength

    Returns:
        losses: Per-pair loss values, shape (T,)
        rewards: Implicit reward signals, shape (B,)
    """
    # Extract log-probs for preferred and less preferred completions
    pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
    ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]

    # Log-ratios between preferred and less preferred completions
    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps

    # DPO loss: negative log-sigmoid of the scaled difference of log-ratios
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))

    # Implicit rewards (detached so no gradient flows through them)
    rewards = beta * (pi_logps - ref_logps).detach()

    return losses, rewards
```
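For concreteness, here is a tiny, hypothetical usage of `dpo_loss` with dummy sequence-level log-probabilities; the batch layout (preferred completions first, then the less preferred ones) is an assumption made for this example only.

```python
import torch

# Dummy log-probabilities for 3 preference pairs (B = 6 sequences).
# Indices 0-2: preferred completions, indices 3-5: less preferred ones.
pi_logps = torch.tensor([-12.3, -15.1, -9.8, -14.0, -15.6, -11.2])
ref_logps = torch.tensor([-13.0, -15.0, -10.5, -13.5, -15.2, -10.9])
yw_idxs = torch.tensor([0, 1, 2])
yl_idxs = torch.tensor([3, 4, 5])

losses, rewards = dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta=0.1)
print(losses.mean())  # scalar loss to backpropagate through the policy model
print(rewards)        # implicit per-sequence rewards (useful for logging)
```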
Step 4: Model Training Settings (as reported in the original paper: https://arxiv.org/pdf/2305.18290)
Optimizer and Hyperparameters:
- Optimizer: RMSprop.
- Learning Rate: $1 \times 10^{-6}$.
- Batch Size: 64.
- Linear warmup: Gradually increase the learning rate from 0 to $1 \times 10^{-6}$ over the first 150 steps.
Task-Specific Adjustments:
- For summarization: Set $\beta = 0.5$.
Training Process:
- Load initial parameters from the reference model.
- Compute losses and rewards in batches using the preference dataset.
- Update the policy model step by step, balancing preference optimization and KL divergence constraints (a minimal training-loop sketch follows below).
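To tie the steps together, here is a minimal, hypothetical training-loop sketch under the settings listed above. It reuses `policy_model`/`ref_model` from the initialization sketch and the `dpo_loss` function; `compute_logps` (summing per-token log-probabilities of each completion) and `dataloader` are placeholder names, not part of the paper's code.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.RMSprop(policy_model.parameters(), lr=1e-6)
# Linear warmup of the learning rate over the first 150 steps.
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 150))

beta = 0.1  # example value; the settings above use beta = 0.5 for summarization

for batch in dataloader:  # batches of preference pairs (batch size 64 above)
    pi_logps = compute_logps(policy_model, batch)    # gradients flow through these
    with torch.no_grad():
        ref_logps = compute_logps(ref_model, batch)  # frozen baseline

    losses, rewards = dpo_loss(
        pi_logps, ref_logps, batch["yw_idxs"], batch["yl_idxs"], beta
    )

    optimizer.zero_grad()
    losses.mean().backward()
    optimizer.step()
    scheduler.step()
```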
3. Applications and Advantages
Use Cases:
- Dialogue Generation:
- Optimize conversational quality to better match user expectations.
- Summarization Tasks:
- Train models to generate concise or detailed summaries based on preferences.
- Translation Models:
- Fine-tune translations to match linguistic and cultural nuances.
Comparison with Traditional Methods:
Method | Features | Advantages |
---|---|---|
PPO | Online optimization, requires environment feedback | Robust, but complex to implement and costly to train |
DPO | Offline preference optimization, no interaction needed | Simple implementation, more stable training |
4. Conclusion
DPO provides an efficient and stable method for integrating human preferences into language model training. By comparing probabilities between preferred and less preferred outputs and applying KL-divergence constraints, DPO avoids instability often seen in traditional reinforcement learning, making it ideal for fine-tuning large language models.
Key Takeaways:
- Prepare a preference dataset labeled with “preferred” and “less preferred” outputs.
- Initialize a reference model for stable training.
- Optimize the preference loss function while controlling distribution shifts.
- Tune hyperparameters to balance preference learning and exploration.
DPO simplifies reinforcement learning and offers a practical solution for preference-based fine-tuning of pre-trained language models.
Postscript
Completed in Shanghai at 21:19 on December 26, 2024, with the assistance of the GPT-4o model.