DPO: Direct Preference Optimization 介绍

DPO 是 RLHF 的屌丝版本,RLHF 需要加载 4 个模型(2个推理,2个训练),DPO 只需要加载 2 个模型(1个推理,一个训练)。

RLHF:

DPO:

 

DPO 原理

DPO 的本质是监督对比学习:通过对每条prompt提供两条不同的answer,并给出这两个answer的偏好偏序,让模型输出更接近good answer,同时更远离 bad answer。

这个过程中并不强制要求上述两者同时满足,只要接近good answer的程度大于bad answer就是有效的训练,比如与good answer远离了,但是与bad answer远离的更多也是有效的。

DPO loss

 

σ :sigmoid函数

β :超参数,一般在0.1 - 0.5之间

y_w :某条偏好数据中好的response,w就是win的意思

y_l :某条偏好数据中差的response,l就是loss的意思,所以偏好数据也叫comparision data

\pi_\theta(y_w|x) :给定输入x, 当前policy model生成好的response的累积概率(每个tokne的概率求和,具体看代码)

\pi_{ref}(y_l|x) :给定输入x, 原始模型(reference model)生成坏的response的累积概率

开始训练时,reference model和policy model都是同一个模型,只不过在训练过程中reference model不会更新权重。

简化形式:忽略 logsigmoid 并取对数

由于最初loss前面是有个负号的,所以优化目标是让本简化公式最大,即希望左半部分和右半部分的margin越大越好,左半部分的含义是good response相较于没训练之前的累积概率差值,右半部分代表bad response相较于没训练之前的累计概率差值,如果这个差值,即margin变大了。

 DPO 数据集

可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset

 DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。

Huagging Face DPO Trainer

与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。 

 dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

Loss 选择:

  • RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
  • IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
  • cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
  • KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。

简单实例

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy

torch.manual_seed(0)
if __name__ == "__main__":
    # 超参数
    beta = 0.1
    # 加载模型
    policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))
    reference_model = deepcopy(policy_model)

    # data
    prompt_ids = [1, 2, 3, 4, 5, 6]
    good_response_ids = [7, 8, 9, 10]
    # 对loss稍加修改可以应对一个good和多个bad的情况
    bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]

    # 转换成模型输入 [3, 10]
    input_ids = torch.LongTensor(
        [prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
    )
    # labels 提前做个shift [3, 9]
    labels = torch.LongTensor(
        [
            [-100] * len(prompt_ids) + good_response_ids,
            *[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
        ]
    )[:, 1:]
    loss_mask = (labels != -100)
    labels[labels == -100] = 0
    # 计算 policy model的log prob
    # policy_model(input_ids)["logits"] [3, 10, 1000] 句末的推理结果无效直接忽略
    logits = policy_model(input_ids)["logits"][:, :-1, :]
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    all_logps = (per_token_logps * loss_mask).sum(-1)
    # 暂时写死第一个是good response的概率, 三个例子中第一个是 good answer, 后两个是 bad answer
    policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

    # 计算 reference model的log prob
    with torch.no_grad():
        logits = reference_model(input_ids)["logits"][:, :-1, :]
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        all_logps = (per_token_logps * loss_mask).sum(-1)
        # 暂时写死第一个是good response的概率
        reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]

    # 计算loss,会自动进行广播
    logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
    loss = -F.logsigmoid(beta * logits).mean()
    print(loss)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值