一、定义
- 理解GRPO算法
- GRPO 工作步骤
- 奖励的理解
- 目标函数的理解
- trl 源码中代码讲解
二、实现
参考:https://blog.csdn.net/yaohaishen/article/details/145282804
- 理解GRPO算法
在传统的强化学习中,模型(称为“策略模型”)会根据环境给出的奖励信号来调整自己的行为。通常,这会涉及一个额外的模型(称为“批评模型”),用来评估当前策略的好坏。然而,批评模型的训练既复杂又耗费计算资源。
GRPO 的核心思想是简化这个过程:它不需要批评模型,而是通过组内相对奖励来优化策略模型。具体来说,GRPO 会从当前策略中采样一组输出,然后根据这些输出的相对表现来调整策略,而不是依赖一个单独的批评模型。
即: 训练时不一定需要奖励模型,而是使用规则进行代替(见案例)。
GRPO 模型中的参考模型为自身(原生模型),因此,训练后的模型与训练前的模型差异不会很大。
2.GRPO 工作步骤
GRPO 的工作流程可以分为以下几个步骤:
1. 采样一组输出:
1.1 对于每个问题,GRPO 会从当前策略中采样一组输出(例如,生成多个不同的答案或推理过程)。
1.2 这些输出可以看作是模型对同一个问题的不同“尝试”。
2. 计算组内相对奖励:
2.1 对这组输出进行评分,计算每个输出的奖励(例如,答案是否正确、推理过程是否合理)。
2.2 然后,GRPO 会计算每个输出的相对优势,即它的奖励相对于组内其他输出的表现如何。
3. 优化策略模型:
3.1 根据这些相对优势,GRPO 会调整策略模型,使得表现较好的输出更有可能被生成,而表现较差的输出被抑制。
3.2这个过程通过数学公式(如梯度上升)来实现,逐步优化模型的策略。
GRPO 训练流程细化:
1. 初始化模型
1.1选择一个预训练的大型语言模型(如 DeepSeek-V3-Base)作为基础模型。
1.2 这个模型已经具备一定的语言理解和生成能力,但需要通过 RL 进一步优化其推理能力。
2 定义任务和奖励
2.1任务:例如数学问题求解、代码生成、逻辑推理等。
奖励:根据任务定义奖励函数。例如:准确性奖励:答案是否正确。
2.2格式奖励:推理过程是否遵循指定的格式(如 <think> 和 <answer> 标签)。
3 采样一组输出
3.1对于每个问题,从当前策略中采样一组输出(例如,生成多个不同的答案或推理过程)。
这些输出可以看作是模型对同一个问题的不同“尝试”。
4 计算组内相对奖励
4.1对这组输出进行评分,计算每个输出的奖励(例如,答案是否正确、推理过程是否合理)。
然后,计算每个输出的相对优势,即它的奖励相对于组内其他输出的表现如何。
5 优化策略模型
5.1根据这些相对优势,调整策略模型,使得表现较好的输出更有可能被生成,而表现较差的输出被抑制。
这个过程通过数学公式(如梯度上升)来实现,逐步优化模型的策略。
3.奖励的理解
4.目标函数的理解
高奖励输出 = 策略比率* 相对优势Ai
KL 散度是待训练模型与参考模型计算而得,保证训练模型与参考模型相似,避免训偏。GRPO 算法中,参考模型不用传,才用模型本身,从而使该模型与本身相似。
- trl 源码中代码讲解
#参考模型生成
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep)
else:
with self.accelerator.unwrap_model(model).disable_adapter(): # 关闭适配器,原生模型作为参考模型
ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
#计算模型与参考模型的KL 散度
# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# Compute the rewards 计算奖励
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions 计算出奖励
rewards = rewards_per_func.sum(dim=1)
# Compute grouped-wise rewards 计算组内奖励均值和方差
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages 标准化奖励来计算优势
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) #优势标准化
#计算最后的损失函数
# x - x.detach() allows for preserving gradients from x #x - x.detach()允许从x中保留梯度
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
#per_token_logps - per_token_logps.detach() 的理解:策略比例 采用梯度代替
1. per_token_logps - per_token_logps.detach()的目的是在计算损失函数时,排除那些不需要梯度更新的部分,通常用于避免反向传播到某些特定的操作或变量。