一、DPO的大模型的技术点
直接偏好优化(Direct Preference Optimization,DPO)是当前将大型语言模型(LLM)与人类偏好对齐的热门方法之一。借助 LoRA 和 QLoRA 等参数高效微调技术,我们可以在单GPU卡上对拥有80 亿参数的模型(如 Llama 3.1 8B 和 Qwen2.5 7B)进行 DPO 训练,当然训练序列可能较短。但如果更大的模型,比如72B,就需要使用多GPU卡。
技术点
举个例子,假设我们想在一台拥有 8 块 H100 GPU(总共 640 GB 显存)的机器上,对一个 700 亿参数的模型进行 DPO 训练。我们需要考虑以下几点:
-
策略模型(Policy Model):我们要训练的模型,占用约 140 GB 的显存。
-
参考模型(Reference Model):DPO 需要一个参考模型,通常与策略模型结构相同,也占用约 140 GB 的显存。
这样,仅模型参数就已经用掉了 280 GB 的显存,大约是总显存的 43.75%。再加上优化器状态,比如使用 AdamW 优化器,每个参数会有两个额外的状态变量。如果这些状态变量以 16 位精度存储,会占用额外的 280 GB 显存。算下来,我们已经用了 560 GB 的显存,只剩下 80 GB 了。这些剩余的显存还要用于存储激活值和梯度。如果不采取特殊的方法,单靠一台机器恐怕无法训练。
二、微调并行化
我们可以使用 PyTorch 的 Fully Sharded Data Parallel(FSDP)技术,配合像 LoRA 和 QLoRA 这样的参数高效微调方法。FDSP类似于DeepSpeed的ZeRO3技术。
FSDP 是一种分布式训练技术,它可以将模型的参数、优化器状态和梯度分片,并分布到多个设备上(比如 GPU)。在前向和反向传播过程中,只有需要的参数片会被加载到内存中,计算完成后就会释放。这大大降低了内存需求。
当然在更大模型的训练的时候,可以使用 DeepSpeed 。DeepSpeed 那样需要大量的内存来存储全精度的模型参数。
在测试中,我两种方法都尝试了。我的环境是两个H100的VM:
我先展示使用accelerate的方法。
上述完整代码见我的repo。
需要指出的是,出于节省显存的目的,我设置的参考模型是通过LoRA adapter加载的,参考模型的基础模型和策略模型是一个。
启动训练:
#accelerate launch --config_file config_fsdp.yaml fsdp+QLoRA.py
三、 DPO 训练输出的各个字段解释
在 DPO 训练中,模型会被提供一组对话,每组包含相同的“提示(prompt)”或“问题(question)”,以及对应的“被选(chosen)”和“被拒(rejected)”回复。模型需要学习区分这些回复,倾向于生成高质量的“被选”回复。
DPO 训练数据解析
训练数据,包括:
-
来源(source):Airoboros
-
被选回复(chosen):包含多轮对话
-
被拒回复(rejected):包含多轮对话
-
提示(prompt):一段描述性的文字
-
问题(question):与提示相同的文字
有时候数据中,“prompt” 和 “question” 可能是相同的,这在某些训练设置中可能用作对话的起始点。
接下来,我结合训练数据,大致介绍DPO训练的过程和结果。
DPO(直接偏好优化)的核心思想
1.DP的核心目标
-
目标: 在不显式训练奖励模型的情况下,直接利用人类偏好数据对模型进行优化。
-
参考模型的引入: 为了防止模型在优化过程中偏离原有的语言能力,DPO 引入了参考模型(通常是初始模型的副本,参数固定不变),作为正则化项。
2. 训练数据
-
提示(Prompt): 用户输入,例如:“请解释水的三态变化。”
-
被选回复(Chosen Reply): 人类评估为高质量、符合预期的回复。
-
被拒回复(Rejected Reply): 人类评估为质量较低、不符合预期的回复。
3. 训练过程
步骤 1:计算对数概率
对于当前模型(参数为 θ):
-
被选回复的对数概率:
`log_p_model_chosen = log( π_θ( chosen_reply | prompt ) )`
-
被拒回复的对数概率:
`log_p_model_rejected = log( π_θ( rejected_reply | prompt ) )`
对于参考模型(参数固定):
-
被选回复的对数概率:
`log_p_ref_chosen = log( π_ref( chosen_reply | prompt ) )`
-
被拒回复的对数概率:
`log_p_ref_rejected = log( π_ref( rejected_reply | prompt ) )`
步骤 2:计算偏好差值
-
被选回复的偏好差值:
`delta_chosen = log_p_model_chosen - log_p_ref_chosen`
-
被拒回复的偏好差值:
`delta_rejected = log_p_model_rejected - log_p_ref_rejected`
步骤 3:构建损失函数
-
损失函数的形式:
`L(θ) = -log( exp( delta_chosen / β ) / [ exp( delta_chosen / β ) + exp( delta_rejected / β ) ] )`
其中,β 是控制温度的超参数。
-
目标: 最小化损失函数 ( L(θ) ),使得模型更倾向于生成被选回复,而不是被拒回复。
4. 示例
以您的例子为例:
-
Prompt(提示): “请解释水的三态变化。”
-
Chosen Reply(被选回复):
“水有三种状态:固态、液态和气态。温度的变化会导致水在这些状态之间转化,例如冰融化成水,水蒸发成水蒸气。”
-
Rejected Reply(被拒回复):
“水是一种液体,在自然界中很常见。”
步骤:
- 计算对数概率
-
log_p_model_chosen = -5
-
log_p_model_rejected = -7
-
log_p_ref_chosen = -6
-
log_p_ref_rejected = -6
-
假设性的数值(用于说明):
- 计算偏好差值
-
delta_chosen = -5 - (-6) = 1
-
delta_rejected = -7 - (-6) = -1
- 计算损失函数(β = 1)
-
计算分子:
`exp( delta_chosen / β ) = exp(1) ≈ 2.718`
-
计算分母:
`exp( delta_chosen / β ) + exp( delta_rejected / β ) = exp(1) + exp(-1) ≈ 2.718 + 0.368 ≈ 3.086`
-
计算损失:
`L(θ) = -log( 2.718 / 3.086 ) ≈ -log(0.880) ≈ 0.127`
-
损失较小,表示模型对被选回复的偏好已高于被拒回复。
- 优化模型参数
- 通过反向传播,最小化损失 ( L(θ) ),进一步提高模型对被选回复的偏好。
5. 参考模型的作用
-
正则化效果: 防止模型过度偏离初始语言模型,保证生成文本的质量和多样性。
-
稳定训练过程: 提供一个固定的对照,使模型的更新更加平稳,避免发生梯度爆炸或消失。
6. 总结
-
DPO 训练过程: 模型利用
prompt
、chosen
、rejected
和参考模型,直接优化自身,使其生成的回复更符合人类偏好。 -
参考模型不可或缺: 它在损失函数中提供了正则化项,确保模型在学习人类偏好的同时,保持原有的语言能力和知识。
AI大模型学习路线
如果你对AI大模型入门感兴趣,那么你需要的话可以点击这里大模型重磅福利:入门进阶全套104G学习资源包免费分享!
扫描下方csdn官方合作二维码获取哦!
这是一份大模型从零基础到进阶的学习路线大纲全览,小伙伴们记得点个收藏!
第一阶段: 从大模型系统设计入手,讲解大模型的主要方法;
第二阶段: 在通过大模型提示词工程从Prompts角度入手更好发挥模型的作用;
第三阶段: 大模型平台应用开发借助阿里云PAI平台构建电商领域虚拟试衣系统;
第四阶段: 大模型知识库应用开发以LangChain框架为例,构建物流行业咨询智能问答系统;
第五阶段: 大模型微调开发借助以大健康、新零售、新媒体领域构建适合当前领域大模型;
第六阶段: 以SD多模态大模型为主,搭建了文生图小程序案例;
第七阶段: 以大模型平台应用与开发为主,通过星火大模型,文心大模型等成熟大模型构建大模型行业应用。
100套AI大模型商业化落地方案
大模型全套视频教程
200本大模型PDF书籍
👉学会后的收获:👈
• 基于大模型全栈工程实现(前端、后端、产品经理、设计、数据分析等),通过这门课可获得不同能力;
• 能够利用大模型解决相关实际项目需求: 大数据时代,越来越多的企业和机构需要处理海量数据,利用大模型技术可以更好地处理这些数据,提高数据分析和决策的准确性。因此,掌握大模型应用开发技能,可以让程序员更好地应对实际项目需求;
• 基于大模型和企业数据AI应用开发,实现大模型理论、掌握GPU算力、硬件、LangChain开发框架和项目实战技能, 学会Fine-tuning垂直训练大模型(数据准备、数据蒸馏、大模型部署)一站式掌握;
• 能够完成时下热门大模型垂直领域模型训练能力,提高程序员的编码能力: 大模型应用开发需要掌握机器学习算法、深度学习框架等技术,这些技术的掌握可以提高程序员的编码能力和分析能力,让程序员更加熟练地编写高质量的代码。
LLM面试题合集
大模型产品经理资源合集
大模型项目实战合集
👉获取方式:
😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓
