5.5.3 拒绝采样和监督微调
在推理导向的强化学习(RL)训练收敛后,接下来利用由此产生的检查点为下一轮收集监督微调(Supervised Fine-Tuning, SFT)数据。这一阶段的目标是进一步优化模型,使其不仅在推理任务上表现出色,还能在其他通用任务中表现良好。
1. 拒绝采样
(1)定义与目的
拒绝采样(Rejection Sampling)是一种蒙特卡洛方法,用于从复杂的目标概率分布中生成随机样本。当直接从目标分布中采样困难或不可行时,使用一个易于采样的提议分布,并根据某种接受概率来决定是否接受采样结果。在DeepSeek-R1中,拒绝采样的目的是从强化学习训练后的模型输出中筛选出高质量的样本,用于后续的监督微调。
(2)实现过程
- 选择提议分布:选择一个易于直接采样且覆盖目标分布支持的提议分布。
- 确定缩放常数:找到一个常数,使得对于所有的样本,提议分布的值不超过目标分布的值乘以该常数。
- 采样过程:先从提议分布中生成一个样本,然后从均匀分布中采样一个随机数;最后计算接受概率,如果随机数小于接受概率,则接受该样本;否则,拒绝该样本并重新采样。
2. 监督微调(Supervised Fine-Tuning, SFT)
(1)定义与目的
监督微调(SFT)是在预训练模型的基础上,使用标注数据进行进一步训练的过程。其目的是使模型在特定任务上表现得更好。在DeepSeek-R1中,SFT的目标是通过高质量的数据进一步优化模型,使其在推理任务和其他通用任务中都能表现出色。
(2)实现过程:
- 数据准备:使用拒绝采样生成的高质量数据,以及其他领域的数据(如写作、角色扮演等)。
- 模型微调:在这些数据上对模型进行微调,以提高模型在各种任务上的表现。
3. DeepSeek-R1推理数据的生成与筛选
DeepSeek团队策划了一系列推理提示,并通过拒绝采样(Rejection Sampling)从上述强化学习训练的检查点中生成推理轨迹。在之前的推理导向强化学习阶段,主要关注的是可以使用基于规则的奖励进行评估的数据。然而,在这一阶段进一步扩展了数据集,加入了使用生成式奖励模型评估的数据。具体来说,将真实值和模型预测输入DeepSeek-V3模型进行判断,以评估模型输出的质量。
此外,由于模型在某些情况下会生成混乱且难以阅读的输出,对生成的数据进行了严格的筛选操作。过滤掉了包含混合语言、长段落和代码块的链式推理,以确保数据的高质量和可读性。对于每个推理提示,采样了多个响应,并仅保留正确的响应。通过这种方式,总共收集了大约600,000个与推理相关的训练样本。
4. DeepSeek-R1非推理数据的生成与整合
除了推理数据外,还收集了非推理数据,包括写作、事实问答、自我认知和翻译等任务。对于这些非推理任务,沿用了DeepSeek-V3的整体流程,并复用了DeepSeek-V3的部分监督微调(SFT)数据集。具体来说,这一流程包括以下几个关键步骤:
(1)数据生成与整合
- 对于某些非推理任务,通过提示调用DeepSeek-V3生成潜在的思维链(Chain of Thought, CoT),以便在回答问题之前提供更详细的推理过程。
- 对于更简单的查询(例如“你好”),不会提供推理链作为回应,因为这些任务不需要复杂的推理。
(2)数据筛选与优化
对生成的数据进行了筛选,确保数据的质量和可读性。例如,过滤掉了混合语言、长段落和代码块的输出。最终,收集了大约20万个与推理无关的训练样本。
3. DeepSeek-R1的监督微调(SFT)
监督微调(Supervised Fine-Tuning, SFT)是指在预训练模型的基础上,使用标注数据进行进一步训练的过程,旨在使模型在特定任务上表现得更好。在 DeepSeek-R1 的训练流程中,SFT 起到了至关重要的作用,不仅进一步优化了模型在推理任务上的表现,还增强了其在其他通用任务中的能力。
(1)使用策划的数据集进行微调
DeepSeek-R1 使用了约 80 万个样本的数据集进行微调,这些样本包括推理和非推理数据。具体说明如下所示:
- 推理数据:通过拒绝采样从强化学习训练后的模型中生成推理轨迹。对于每个提示,生成多个响应,并保留正确的响应作为训练样本。这一过程扩展了数据集,包含使用生成奖励模型的数据,过滤掉混合语言、长段落和代码块,最终收集约 60 万个推理相关训练样本。
- 非推理数据:采用 DeepSeek-V3 的 Pipeline,重用部分 SFT 数据,针对某些任务生成潜在的思维链。对于更简单的查询,例如“hello”,不提供 CoT 作为响应。最终收集约 20 万个非推理训练样本。
(2)微调过程
- 数据准备:使用上述约 80 万个样本的精选数据集对 DeepSeek-V3-Base 进行两个 epoch 的微调。
- 模型微调:在这些数据上对模型进行微调,以提高模型在各种任务上的表现。具体来说,使用 LoRA 适配器应用于关键投影层,从而减少微调期间的内存和计算要求。例如配置和运行训练过程的代码如下:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=max_seq_length,
dataset_num_proc=2,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
max_steps=60,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=10,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
),
)
trainer_stats = trainer.train()
(3)微调后的效果
- 推理任务:通过推理数据的微调,模型在推理任务上的表现得到了进一步优化。例如,在数学问题、编码任务和科学实验分析等推理密集型任务中,模型的逻辑思维能力和问题解决能力得到了显著提升。
- 通用任务:通过非推理数据的微调,模型在写作、角色扮演、事实问答等通用任务中的表现也得到了增强。这使得模型不仅在推理任务上表现出色,还能在其他任务中提供高质量的输出。
总之,通过使用约 80 万个样本的数据集对 DeepSeek-V3-Base 进行两个 epoch 的微调,DeepSeek-R1 模型不仅在推理任务上表现出色,还在其他通用任务中表现良好。这一过程不仅进一步优化了模型的推理能力,还增强了其在多种任务中的通用性,为后续的应用奠定了坚实的基础。