深度解析DeepSpeed-Chat RLHF:PPO阶段代码详解(2)

数据处理

这里和
DeepSpeed-Chat RLHF 阶段代码解读(1) —— 奖励函数阶段
处理基本一致,唯一的区别是输入不是 prompt + response,而是只有 prompt,response 靠 actor model 生成。

PPO

初始化

PPO 训练要进行模型初始化,这里一共四个模型:

Actor model、Reference model: 开始的时候这两个模型是一样的,用途是一样的。

Critic model、Reward model: 开始的时候这两个模型是一样的,但是用途是不一样的,一个是用来产生 critic value,一个是用来产生 reward 的,虽然结构是一样的。另外,critic model 会随着 ppo 训练更新,但是在 ppo 阶段,Reward model 是不变的,比较抽象。

Rewards

Token Level KL-Penalty

def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):

    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]

    return rewards

这里按照参考 4 里说的,除了 eos token 对应的 responese 的 reward score,对其余的每个时间步都增加了一个正则项,因为正则项的格式就是
深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
提到的 advantage 项为负的情况一模一样,只是这里不涉及 loss 的计算。因此,这里的目的:新策略和之前的策略不一致,增加探索,得到负的 KL 散度,从而提高奖励。

GAE
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    # 逆序 advantage, 这样就按时间步顺序得到每个时间步的 advantage
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

可以看到每个 rlhf 里的实现,和
DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读
的实现,没有本质上的区别。

Loss
  • actor loss

actor_loss 的核心是重要性采样,重要性采样的核心思想是:旧策略的采样在一定时间段内,可以用于新策略的训练,提高数据的有效利用。关于 loss 为什么可以写成 -ratio * advantage 的证明,可以看:

深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇

使用 torch.max 的原因,和 critic loss 一样,为了避免乐观估计,增加探索。

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss

  • critic loss
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 这是损失函数的核心部分。首先,计算vf_loss1和vf_loss2中较大的那个值,然后乘以mask。
    # 这样做是为了只考虑有效的样本(由mask指示)。然后,取这个乘积的总和,除以mask中有效样本的数量(mask.sum()),得到平均损失。
    # 选取更大的 loss 是为了增加探索,防止过于乐观的估计。
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

RLHF 整体的流程

结合之前的文章,以及本篇文章的数据处理和 PPO 章节,相信读者对 RLHF 无论是原理和代码都有了一定的理解,这里再从整体梳理下使用 PPO 进行 RLHF 的流程。

step1 prompt 输入 actor model 得到 response

step2(重要性采样): prompt + response 分别输入到 actor_model 和 reference model 得到 log_probs、ref_log_probs、reward_score、values,这部分的数据可以重复利用。

step3 计算 critic_loss、actor_loss,更新 actor_model。

大致上,ppo 主要就这个三个步骤。

整个流程下来,我的感觉,很繁琐,难训练,所以目前主流大模型很少使用原始的这套 RLHF 流程,更多使用 dpo 算法,而且 RLHF 的数据有限,很难对所有的 response 有一个公平的 rewar,所以下一个系列文章会介绍利用 dpo 的 RLAIF 算法,如 SPIN、self-reward etc。欢迎关注。

参考

  1. Negative KL-divergence RLHF implementation · Issue #736 · huggingface/trl (github.com)
    [2307.04964] Secrets of RLHF in Large Language Models Part I: PPO (arxiv.org)
  2. 关于ppo阶段,reward分数计算的问题 · Issue #26 · OpenLMLab/MOSS-RLHF (github.com)
  3. 2009.01325v3.pdf (arxiv.org)
  • 14
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值