DPO(Direct Preference Optimization)
原文来自于 https://arxiv.org/pdf/2305.18290,
Bradley-Terry (BT)模型,假设人的喜欢遵循下面的公式,给定x,得到
y
1
y_1
y1和
y
2
y_2
y2分别遵循以下关系,其中
r
∗
r^*
r∗是对奖励的估计:
p
∗
(
y
1
≻
y
2
∣
x
)
=
exp
(
r
∗
(
x
,
y
1
)
)
exp
(
r
∗
(
x
,
y
1
)
)
+
exp
(
r
∗
(
x
,
y
2
)
)
p^*(y_1 \succ y_2 \mid x) = \frac{\exp(r^*(x, y_1))}{\exp(r^*(x, y_1)) + \exp(r^*(x, y_2))}
p∗(y1≻y2∣x)=exp(r∗(x,y1))+exp(r∗(x,y2))exp(r∗(x,y1))
除一下得到下面的形式,刚好是可以sigmoid形式
p
∗
(
y
1
≻
y
2
∣
x
)
=
1
1
+
exp
(
r
∗
(
x
,
y
2
)
−
r
∗
(
x
,
y
1
)
)
=
σ
(
r
∗
(
x
,
y
1
)
−
r
∗
(
x
,
y
2
)
)
p^*(y_1 \succ y_2 \mid x) = \frac{1}{1 + \exp(r^*(x, y_2)-r^*(x, y_1))} = \sigma(r^*(x, y_1)-r^*(x, y_2))
p∗(y1≻y2∣x)=1+exp(r∗(x,y2)−r∗(x,y1))1=σ(r∗(x,y1)−r∗(x,y2))
所以重点1来了:有了BT Model的假设,这个preference是一个sigmoid的形式,否则二分类应该是一个CE的形式,这种sigmoid的形式在后面推导最终表达式的时候有一些便利:
最终DPO的loss函数形式是
p
∗
(
y
1
≻
y
2
∣
x
)
=
1
1
+
exp
(
β
log
π
∗
(
y
2
∣
x
)
π
ref
(
y
2
∣
x
)
−
β
log
π
∗
(
y
1
∣
x
)
π
ref
(
y
1
∣
x
)
)
\begin{equation} p^*(y_1 \succ y_2 \mid x) = \frac{1}{1 + \exp \left( \beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\text{ref}}(y_2 \mid x)} - \beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\text{ref}}(y_1 \mid x)} \right)} \end{equation}
p∗(y1≻y2∣x)=1+exp(βlogπref(y2∣x)π∗(y2∣x)−βlogπref(y1∣x)π∗(y1∣x))1
这里的
r
∗
(
x
,
y
)
r^*(x,y)
r∗(x,y)实际上是借鉴PPO里面的思路应该表示成以下形式(由于拉格朗日乘数法所以多了一个Z,细节参考原文推导),刚好这个Z(x)由于Bradley-Terry假设就被消掉了,这也是BT Model的重点2,所以得到了上面公式(1)作为DPO的loss函数
r
∗
(
x
,
y
)
=
β
log
π
∗
(
y
∣
x
)
π
ref
(
y
∣
x
)
+
β
log
Z
(
x
)
r^*(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)
r∗(x,y)=βlogπref(y∣x)π∗(y∣x)+βlogZ(x)
代码
import torch
import torch.nn.functional as F
def dpo_loss(policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
beta: float = 0.1,
label_smoothing: float = 0.0) -> torch.Tensor:
"""
DPO损失函数实现
参数说明:
policy_chosen_logps: 策略模型对优选回答的log概率 (batch_size,)
policy_rejected_logps: 策略模型对非优选回答的log概率 (batch_size,)
ref_chosen_logps: 参考模型对优选回答的log概率 (batch_size,)
ref_rejected_logps: 参考模型对非优选回答的log概率 (batch_size,)
beta: 温度参数,控制KL惩罚强度 (默认0.1)
label_smoothing: 标签平滑系数 (默认0.0)
"""
# 计算对数几率差
policy_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = ref_chosen_logps - ref_rejected_logps
# 核心计算逻辑
logits = policy_logratios - ref_logratios
# 带标签平滑的损失计算
losses = - (
F.logsigmoid(beta * logits) * (1 - label_smoothing) +
F.logsigmoid(-beta * logits) * label_smoothing
)
# 计算奖励值(可选)
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
return losses.mean(), chosen_rewards, rejected_rewards
在DPO损失中使用F.logsigmoid
而非直接计算sigmoid,主要基于以下关键原因:
1. 数值稳定性(核心考量)
当输入值非常大(正或负)时:
-
直接计算
sigmoid(x)
会导致:- x > 25时:sigmoid(x)=1.0 → log(1.0)=0(有效计算)
- x < -25时:sigmoid(x)=0.0 → log(0)=-inf(梯度爆炸)
-
F.logsigmoid
内部实现:
def log_sigmoid(x):
return -torch.nn.functional.softplus(-x) # 使用softplus的负值
这种实现方式避免了直接计算指数,保证:
- x→+∞时:log_sigmoid(x)→0(稳定)
- x→-∞时:log_sigmoid(x)→-x(线性衰减而非无穷)
2. 梯度保持能力
对比两种计算方式的梯度曲线:
计算方式 | 梯度表达式 | 梯度特性 |
---|---|---|
log(sigmoid(x)) | 1 - sigmoid(x) | 梯度范围始终在(0,1) |
F.logsigmoid(x) | sigmoid(-x) | 自动实现梯度裁剪效果 |
实际反向传播时:
# 常规计算路径
x → exp(x) → 1/(1+exp(-x)) → log(...)
# 数值敏感区:x>20或x<-20时梯度消失
# LogSigmoid专用实现
x → -softplus(-x)
# 梯度计算:1 / (1 + exp(x)) → 始终稳定
IPO(Identity Preference Optimization)
将损失函数替换为(logits - 1/(2*beta))**2,主要是可以减少方差
KTO(Kahneman-Tversky Optimization)
KTO简单来说就是average来做reference point,上面DPO每次都是win和loss这样一对pair来比,KTO改成了从average里面取。这样就不再需要pair wise数据了,只需要point wise数据。但上面那个Z(x)姑且假设还能消掉。
作者调研了RL几个loss function,符合KT理论特征,发现人就是收益边际效用递减+损失厌恶,几种RL的loss都是下面图里的趋势。下面图只是画出了log的大概形状,和x轴和y轴的交点并不完全准确
DiffusionDPO
来自于 https://arxiv.org/pdf/2311.12908,问题是DPO是怎么加的呢?有下面几个点比较关键
Expectation to remove redundant predictions
因为stable diffusion有很多中间状态,解决方案是求个均值,下面公式里c是用户输入的prompt
r
(
c
,
x
0
)
=
E
p
θ
(
x
1
:
T
∣
x
0
,
c
)
[
R
(
c
,
x
0
:
T
)
]
r(c, x_0) = \mathbb{E}_{p_\theta(x_{1:T} \mid x_0, c)} \left[ R(c, x_{0:T}) \right]
r(c,x0)=Epθ(x1:T∣x0,c)[R(c,x0:T)]
Jensen’s inequality
实际上就是通过Jensen不等式,把expectation取出来
Estimate p with q,加噪声时候是q,去噪声时候是p
最终得到的loss函数形式如下,含义也比较直观,尽可能接近winning cases,原理losing cases