人工反馈强化学习(RLHF)

论文:Training language models to follow instructions

with human feedback,地址:https://arxiv.org/pdf/2203.02155

1 简介

使语言模型更大并不能从本质上使它们更好地遵循用户的意图。利用人工反馈进行微调是使语言模型与人类意图相一致的一个有希望的方向。

专注于微调对齐语言模型的方法。使用来自人工反馈的强化学习(RLHF)微调GPT-3。这种技术使用人类的偏好作为奖励信号来微调模型。收集了一个数据集,用于在更大的API提示集上对模型的输出进行人工标记的比较。在数据集上训练一个奖励模型(RM),以预测标签器更喜欢输出哪个模型。使用此RM作为奖励函数,并使用PPO算法微调监督学习基线以最大化此奖励。这个过程使GPT-3的行为与特定人群(主要是标签员和研究人员)的声明偏好保持一致,而不是任何更广泛的“人类价值观”概念;将结果模型称为InstructGPT。过程如下图:

InstructGPT模型对RLHF微调分布之外的指令显示出很好的泛化能力。模型能够概括“遵循指令的概念。“他们即使在很少得到直接监督信号的任务上也能保持一定的一致性。

2 相关工作

从人类反馈中强化学习(RLHF)。应用于微调语言模型以总结文本。使用人工反馈作为奖励,如对话,翻译,语义解析,故事生成,评论生成和证据提取。使用书面人工反馈来增强提示并提高GPT-3的性能。还有一些工作在基于文本的环境中使用RL与规范先验对齐智能体。工作可以被视为RLHF在广泛分布的语言任务上对语言模型的直接应用。

修改语言模型的行为以减轻危害。 有许多方法可以改变语言模型的生成行为。在一个小的、以价值为目标的数据集上对LMs进行微调,这提高了模型在问答任务上坚持这些价值的能力。通过删除具有高条件可能性的语言模型生成一组研究人员编写的触发短语的文档来过滤预训练数据集。当在这个过滤后的数据集上训练时,他们的语言模型生成的文本危害更小,代价是语言建模性能略有下降。使用各种方法来提高聊天机器人的安全性,包括数据过滤,在生成过程中屏蔽某些单词或n元语法,特定于安全的控制令牌,以及人工在环数据收集。其他减轻LMs生成偏差的方法使用词嵌入正则化、数据增强、零空间投影以使敏感标记的分布更均匀、不同的目标函数或因果调解分析。还有一项工作是使用第二个(通常较小)语言模型来指导语言模型的生成,该想法的变体已应用于减少语言模型的毒性。

3方法和实验细节

3.1高级方法论

从一个预训练的语言模型开始,希望模型在其中产生对齐输出的提示分布,以及一个训练有素的人工标签团队。然后执行以下三个步骤:

步骤1:收集演示数据,并训练一个监督策略。 标记器提供了输入提示分布上所需行为的演示。然后,使用监督学习在该数据上微调预训练的GPT-3模型。

步骤2:收集对比数据,并训练奖励模型。 收集了一个模型输出之间比较的数据集,其中标记者表明他们更喜欢给定输入的输出。然后,训练一个奖励模型来预测人类首选的输出。

步骤3:使用PPO根据奖励模型优化策略。 使用RM的输出作为标量奖励。使用PPO算法微调监督策略以优化此奖励。

步骤2和3可以连续迭代;在当前最佳策略上收集更多的比较数据,用于训练一个新的RM,然后再训练一个新的策略。在实践中,大多数比较数据来自监督策略,还有一些来自PPO策略。

3.2 数据集

prompt数据集主要由提交给OpenAI API的文本提示组成。

为了训练第一个InstructGPT模型,要求标记者自己编写提示。这是因为需要一个类似指令的提示的初始来源来引导过程,而这些类型的提示通常不会提交给API上的常规GPT-3模型。要求标签员写出三种提示:

• 简单:只是要求标记者提出一个任意的任务,同时确保任务具有足够的多样性。

• 少样本:要求标签者提出一条指令,以及该指令的多个查询/响应对。

• 基于用户:在OpenAI API的waitlist应用程序中有许多用例。要求标记者提出与这些用例相对应的提示。

从这些提示中 ,产生了三个不同的数据集 ,用于微调过程:(1)SFT数据集 ,用labeler演示来训练SFT模 型 ,(2)RM数 据 集 ,用labeler排名来训RMs,以及(3)PPO数据集,没有任何人工标签,这些数据集用作RLHF微调的输入。SFT数据集包含大约13k个训练提示(来自API和labeler-written),RM数据集有33k个训练提示(来自API和labeler-written),PPO数据集有31k个训练提示(仅来自API)。

3.3 任务

训练任务来自两个来 源:(1)标记者编写的提示数据集和(2)提交给API上早期InstructGPT模型的提示数据集。这些提示非常多样化,包括生成、问题回答、对话、摘要、提取和其他自然语言任务。

3.4 人类数据收集

为了提供演示和比较数据,并进行主要评估,在Upwork和ScaleAI上雇用了大约40名承包商。与早期在摘要任务上收集人类偏好数据的工作相比,所提出的输入跨越了更广泛的任务,并且偶尔可以包括有争议和敏感的主题。目标是选择一组标签者,这些标签者对不同人口群体的偏好敏感,并且善于识别潜在有害的输出。因此,进行了一个筛选测试,旨在测量标记器在这些轴上的性能。选择了在这个测试中表现良好的标签者;

在训练和评估期间,对齐标准可能会发生冲突:例如,当用户请求一个潜在有害的响应时。在训练期间,优先考虑对用户的有用性(不这样做需要做出一些困难的设计决策,将这些决策留给未来的工作)。然而,在最终评估中,要求标签者优先考虑真实性和无害性(因为这是真正关心的)。

3.5 模型

从GPT-3预训练语言模型开始。这些模型在广泛分布的互联网数据上进行训练,并能适应广泛的下游任务,但具有较差的行为特征。从这些模型开始,用三种不同的技术训练模型:

监督微 调(SFT)。 在标签器演示上使用监督学习对GPT-3进行微调 。 训练了16个epoch,使用余弦学习率衰减,残差下降为0.2。根据验证集上的RM分数进行最终的SFT模型选择。发现SFT模型在1次迭代后过拟合验证损失;尽管有这种过拟合,但对更多epoch的训练对RM分数和人类偏好评级都有帮助。

奖励模型(RM)。 从最终的去嵌入层的SFT模型开始删除后,训练了一个模型来接收提示和响应,并输出标量奖励。只使用6B RM,因为这节省了大量的计算,并且发现175B RM训练可能不稳定,因此不太适合用作RL期间的值函数。

RM是在对相同输入的两个模型输出进行比较的数据集上进行训练的。他们使用交叉熵损失,将比较作为标签——奖励的差异代表人类标记者更喜欢一种反应的对数概率。

为了加快比较收集,我们为标签者提供K= 4和K = 9之间的任何位置的排名响应。这将为显示给标记者的每个提示生成  比较。由于在每个标记任务中,比较是非常相关的,发现如果简单地将比较混洗到一个数据集,对数据集的一次遍历会导致奖励模型过拟合。相反,将每个提示作为单个批处理元素训练所有  比较这在计算上更加高效,因为它只需要对每个完成进行一次RM的前向传递(而不是对K完成进行 前向传递),并且,因为它不再过度拟合,它实现了大大提高的验证准确性和日志损失。

具体来说,奖励模型的损失函数为:

中rθ (x, y)是提示x和带参数θ的完成y的奖励模型的标量输出,yw 是yw 和yl 对中的首选完成,D是人工比较的数据集。

最后,由于RM损失对奖励的偏移是不变的,我们使用偏差对奖励模型进行归一化,以便在进行强化学习之前,标记器演示达到平均分数0。

强化学习(RL)。使用PPO在环境上微调SFT模型。环境是一个bandit环境,它呈现一个随机的客户提示并期望对该提示作出响应给定提示和响应,它由奖励模型生成一个奖励,并结束这一回合。在每个标记(token)处添加一个来自监督微调(SFT)模型的每个标记的KL散度(KL - divergence)惩罚项,以减轻奖励模型的过度优化。。value函数由RM初始化。称这些模型为‘PPO’。”

还实验了将预训练梯度混合到PPO梯度中,以解决在公共NLP数据集上的性能回归问题。称这些模型为‘PPO-ptx ‘。“在强化学习训练中最大化以下组合目标函数:

其中 是学习到的强化学习策略, 是监督训练的模型, 是预训练分布。KL奖励系数β和预训练损失系数γ分别控制KL惩罚和预训练梯度的强度。对于“PPO”模型,γ设置为0。除非另有说明,这里InstructGPT指的是PPO-ptx模型。

基线。 将PPO模型的性能与SFT模型和GPT-3进行了比较。还与GPT-3进行了比较,当它提供了一个少次数的前缀来"提示"它进入指令遵循模式(GPT-3提示)。这个前缀放在用户指定的指令前面。

3.6 评价

为了评估模型“对齐”的程度,首先需要澄清在这种情况下对齐的含义。对齐的定义历来是一个模糊和令人困惑的话题,有各种竞争性的建议。目标是训练符合用户意图的模型。更实际地说,为了语言任务的目的,使用类似于Askell et al. (2021)的框架,它将模型定义为一致的。

可以将量化评估分为两个独立的部分:

API分布评估。 主要指标是人类对来自与训练分布相同来源的一组提示的偏好评级。

在公共NLP数据集上的评估。 在两种类型的公共数据集上进行了评估:捕捉了语言模型安全性的某一方面,特别是真实性、毒性和偏见,以及在问答、阅读理解和摘要等传统NLP任务上的零样本表现。

4 结果

为1节中的主张提供了实验证据,分为三部分:API提示分布的结果,公共NLP数据集的结果和定性结果。

4.1 API分布结果

标记者明显更喜欢InstructGPT输出而不是GPT-3的输出。 在测试提示集上,标注者明显偏好InstructGPT跨模型大小的输出。

发现当在API上提交给GPT-3模型的提示上进行评估时,结果没有显著变化(参见图3),尽管PPO-ptx模型在更大的模型尺寸上表现略差。

所提出模型可以泛化到没有产生任何训练数据的"保留"标签者的偏好。 保留标签的人与用来生成训练数据的工人有相似的排名偏好(参见图3)。

公开的NLP数据集不能反映语言模型是如何使用的 。 在图5中 , 将InstructGPT与在FLAN (Wei et al., 2021)和T0 (Sanh et al., 2021)数据集上微调的175B GPT-3基线进行了比较。这些模型的表现优于GPT-3,与精心选择的提示符的GPT-3相当,但比SFT基线差。这表明这些数据集的多样性不足以提高我们的API提示分布的性能。在直接比较中,175B InstructGPT模型输出在4%的时间内优于FLAN模型78 ±,在4%的时间内优于T0模型79 ±。这些模型的李克特分数如图5所示。

认为InstructGPT模型优于FLAN和T0有两个原因。首先,设计了公开的NLP数据集,以捕获易于用自动指标评估的任务,如分类、问答,以及在一定程度上的摘要和翻译。然而,分类和QA只是API客户使用语言模型的一小部分(约18%),而开放式的生成和头脑风暴占了提示数据集的57%。其次,公共NLP数据集很难获得非常高的输入多样性(至少在现实世界的用户会感兴趣的输入类型上)。当然,在NLP数据集中发现的任务确实代表了一种希望语言模型能够解决的指令,因此最广泛的指令遵循模型将结合两种类型的数据集。

4.2 在公共NLP数据集上的实验结果

InstructGPT模型比GPT-3在真实性方面有所改进。 通过对TruthfulQA数据集的人工评估,与GPT-3相比,PPO模型在生成真实和有信息量的输出方面显示出了微小但显著的改进(见图6)。这种行为是默认的:模型不需要明确指示说真话,就可以表现出更好的真实性。

InstructGPT在毒性方面比GPT-3有小的改进 ,但没有偏差。

结果见图7。根据Perspective API,当指示产生安全且令人尊重的输出("尊重的提示")时,InstructGPT模型产生的毒性比GPT-3的输出更小。当删除尊重提示符(“没有提示符”)时,这个优势就消失了。有趣的是,当明确提示产生有害输出时,InstructGPT的输出比来自GPT-3的输出更有害。

为了评估模型生成有偏语音的倾向,还在修改版本的Winogender (Rudingeret al., 2018)和CrowS-Pairs (Nangia et al., 2020)数据集上评估了InstructGPT。这些数据集由一对句子组成,可以突出潜在的偏见。计算每对中生成句子的相对概率,以及相关二进制概率分布的熵(以比特为单位)。完全无偏的模型在每对句子之间没有偏好,因此具有最大的熵。根据这个指标,模型的偏差并不比GPT-3小。PPO-ptx模型显示出与GPT-3相似的偏差,但当被指示遵守行为时,它表现出更低的熵,因此偏差更高。偏差的模式不清楚;似乎被指导的模型对其输出更确定,无论其输出是否表现出刻板行为。

可以通过修改RLHF微调程序来最小化公共NLP数据集上的性能回归。 默认情况下,当在API分布上训练PPO模型时,它会受到“对齐税”的影响,因为它在几个公共NLP数据集上的性能下降。想要一个避免对齐税的对齐程序,因为它鼓励使用未对齐但在这些任务上更有能力的模型。

PPO微调(PPO-ptx)中添加预训练更新 , 缓解了所有数据集上的这些性能回归 ,甚至超 过了HellaSwag上的GPT-3PPO-ptx模型的性能DROP,SQuADv2translation方面仍然落后于GPT-3;需要做更多的工作来研究和进一步消除这些性能回归。

混合预训练更新的性能比增加KL系数的简单解决方案要好。预训练混合系数的一个值,该值可以逆转SQuADv2DROP(用于测试的数据集)上的性能回归,并且验证奖励的减少最小。相反,增加KL系数会导致验证奖励显著下降,并且在掉落和阵容上永远无法完全恢复。将KL模型从PPO init改为GPT-3得到了类似的结果。

4.3 定性结果

InstructGPT模型对RLHF微调分布之外的指令显示出很好的泛化能力。 InstructGPT显示

了遵循非英语语言指令的能力,并为代码进行摘要和问答。

InstructGPT仍然会犯一些简单的错误。 在与175B PPO-ptx模型交互时,注意到它仍然可能犯简单的错误,尽管它在许多不同的语言任务上表现强劲。举几个例子:(1)当给出一个带有错误前提的指令时,模型有时会错误地假设前提为真,(2)模型可能会过度模糊;当给出一个简单的问题时,它有时会说问题没有一个答案,并给出多个可能的答案,即使上下文中有一个相当明确的答案。(3)当指令包含多个显式约束时(例如,模型的性能会下降。”列出20世纪30年代在法国拍摄的10部电影”)或当约束对语言模型来说可能具有挑战性时(例如,用指定数量的句子编写摘要)

在图9中展示了这些行为的一些示例。怀疑行为(2)的出现部分是因为指示标签者奖励认知上的谦逊;因此,它们可能倾向于奖励对冲的输出,这被奖励模型所接受。怀疑行为(1)发生是因为训练集中很少有假设为假前提的提示,并且模型不能很好地泛化到这些示例。相信,通过对抗性数据收集(Dinan et al., 2019b),这两种行为都可以显著减少。

5讨论

5.1对对齐研究的影响

这项研究是使人工智能系统与人类意图保持一致。尽管这项工作专注于当前的语言模型系统,但寻求适用于未来AI系统的通用和可扩展方法。在这里使用的系统仍然相当有限,但它们是当今最大的语言模型之一,将它们应用于广泛的语言任务,包括分类、摘要、问答、创意写作、对话等。

对齐研究的方法在这项工作中是迭代的:正在改进当前人工智能系统的对齐,而不是抽象地关注对齐尚不存在的人工智能系统。这种方法的一个缺点是,不会直接面临只有在对齐超人系统时才出现的对齐问题。然而,方法确实提供了一个明确的经验反馈回路,告诉什么有效,什么无效。相信这种反馈循环对完善对齐技术至关重要,它迫使跟上机器学习的步伐。此外,在这里使用的对齐技术RLHF,是几个对齐超人系统建议中的重要构建块。例如,RLHF是最近关于书籍摘要的工作中的核心方法,该任务表现出了调整超人AI系统的一些困难,因为人类很难直接评估。

从这项工作中,可以从更广泛的对齐研究中吸取教训:

1. 相对于预训练,增加模型对齐的成本是适中的。收集数据和用于训练运行(包括实验运行)的计算成本只是训练GPT-3花费的一小部分。同时,我们的结果表明,RLHF在使语言模型对用户更有帮助方面非常有效,比模型大小增加100倍更有效。这表明,目前增加对现有语言模型对齐的投资比训练更大的模型更具成本效益——至少对于客户的自然语言任务分布来说是这样。

2. 已经看到了一些证据,InstructGPT将“遵循指令”泛化到不监督的设置中,例如在非英语语言任务和与代码相关的任务中。这是一个重要的属性,因为让人工在模型执行的每个任务上监督模型的成本非常高。需要进行更多的研究,以研究这种泛化能力随着能力的提高如何扩展。

3. 能够缓解微调带来的大多数性能下降。如果不是这样,这些性能下降将构成对齐税——对齐模型的额外成本。任何高税收的技术可能都不会被采用。为了避免激励未来的高能力AI系统保持与人类意图不一致,需要具有低对齐税的对齐技术。为此,结果对于RLHF作为一种低税收对齐技术来说是一个好消息。

4. 已经从现实世界的研究中验证了对齐技术。对齐研究历来相当抽象,侧重于理论结果,小型合成领域,或在公共NLP数据集上训练ML模型。该工作为在现实世界的生产中与客户一起使用的人工智能系统的对齐研究提供了基础。这使一个关于技术的有效性和局限性的重要反馈循环成为可能。

5.2 与谁对齐?

当将语言模型与人类意图对齐时,它们的最终行为是基础模型(及其训练数据)、微调数据和

所使用的对齐方法的函数。在这里,将具体描述影响微调数据的一些因素,以最终确定要与什么人和谁保持一致。

这些文献经常使用“人类偏好”或“人类价值观”等术语来构建对齐。“在这项工作中,与

一组标签者的偏好保持一致,这些偏好受到给予他们的指示、他们接收这些指示的环境(为一项有偿工作)以及他们从谁那里接收这些指示的影响。一些关键的注意事项:

首先,将与训练标记器提供的演示和偏好保持一致,这些标记器直接产生用于微调模

型的数据。

其次,作为设计这项研究的研究人员(因此通过代理到更广泛的研究组织OpenAI),正在调整偏好:编写标记说明,标记人员在编写演示和选择他们喜欢的输出时使用指导,在共享聊天室中回答他们关于边界情况的问题。

第三,训练数据是由OpenAI客户发送给OpenAI API Playground上的模型的提示确定

的,因此隐式地对齐了客户认为有价值的东西,在某些情况下,他们的终端用户认为当前使用API有价值的东西。

第四,OpenAI的客户不能代表语言模型的所有潜在或当前用户——更不用说所有受语言模

型使用影响的个人和群体。

目标是证明这种对齐技术可以与特定应用的特定人类参考组对齐。

一种方法是训练模型,这些模型可以根据特定群体的偏好进行限制,或者可以很容易地进

行微调或提示以代表不同的群体。然后,支持不同价值观的团队可以部署和使用不同的模

型。

5.3 局限性

方法论。 InstructGPT模型的行为部分由从承包商获得的人工反馈决定。一些标注任务依

赖于价值判断,而价值判断可能会受到承包商的身份、信仰、文化背景和个人历史的影

响。雇用了大约40名承包商,根据他们在筛选测试中的表现来指导,该测试旨在判断

他们识别和响应敏感提示的能力,以及他们与研究人员在一项有详细说明的标签任务上的

一致率。承包商团队规模较小,因为这有助于与一小部分全职完成任务的承包商进行高带宽沟通。然而,这个群体显然不能代表所有将使用部署的模型并受其影响的人。举个简单的例子,标注者主要是说英语的,数据几乎完全由英语指令组成。

还有许多方法可以改善我们的数据收集设置。

模型。 模型既不是完全对齐的,也不是完全安全的;他们仍然会产生有害或有偏见的输出,编造事实,并在没有明确提示的情况下产生性和暴力内容。

也许模型最大的限制是,在大多数情况下,它们遵循用户的指示,即使这可能会在现实世界中导致伤害。

5.4 开放性问题

这项工作是使用对齐技术微调语言模型以遵循广泛指令的第一步。有许多开放问题需要探

索,以进一步使语言模型的行为与人们实际希望他们做的事情保持一致。

可以尝试许多方法来进一步降低模型产生有毒、有偏差或其他有害输出的倾向。

在这项工作中,如果用户请求一个潜在有害或不诚实的响应,允许模型生成这些输出。

让模型做想做的事情与可控性和可控性文献直接相关。一个有前途的未来路径是将RLHF与其他可操控性方法相结合,虽然主要关注RLHF,但还有许多其他算法可以用于在演示和比较数据上训练策略,以获得更好的结果。例如,可以探索专家迭代,或者使用比较数据子集的更简单的行为克隆方法。人们还可以尝试约束优化方法,从以生成少量有害行为为条件的奖励模型中最大化分数。

比较也不一定是提供对齐信号的最有效方法。例如,可以让标记者编辑模型响应以使其更好,或用自然语言生成对模型响应的批评。

通过将预训练数据合并到RLHF微调中来减轻对齐税的建议,并不能完全缓解性能回归,并

可能使某些任务更有可能出现某些不期望的行为(如果这些行为存在于预训练数据中)。这是

一个值得进一步研究的有趣领域。另一个可能改进方法的修改是过滤预训练混合数据中的有毒内容,或者用合成指令增强此数据。

在对齐指令、意图、显示的偏好、理想的偏好、兴趣和价值观之间存在微妙的差异。Gabriel提倡一种基于原则的一致性方法:换句话说,就是确定“尽管人们的道德信仰存在广泛差异,但仍能得到反思性认可的合理的一致性原则”。在这里,与推断出的用户意图保持一致,但在这一领域还需要更多的研究。

5.5更广泛的影响

目的是通过训练大型语言模型做一组给定的人类希望它们做的事情,来增加它们的积极影响。

对齐技术不是解决与大型语言模型相关的安全问题的万灵药;相反,它们应该被用作更广泛的安全生态系统中的一种工具。

模型与谁对齐的问题非常重要,并且将显著影响这些模型的净影响是积极的还是消极的。

代码实现:

PaLM-rlhf-pytorch源码地址:GitHub - lucidrains/PaLM-rlhf-pytorch: Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM

下载代码后,安装依赖包:

这里可用Python3.9

$ pip install "numpy<2"

下载Pytorch2.0.whl安装文件,下载地址:https://mirrors.aliyun.com/pytorch-wheels/cu117/?spm=a2c6h.25603864.0.0.5bed6223UN87JG

$ pip install /home/heyiqiu/下载/torch-2.0.0+cu117-cp39-cp39-linux_x86_64.whl

$ pip install beartype

在Pycharm中导入上面下载的源码项目,配置虚拟环境,保证安装了需要的依赖包,保证编译不报错。

这里以一个名叫PaLM的模型为例,实际中可以换成自己要训练的模型。

进行SFT训练:

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True # https://arxiv.org/abs/2205.14135
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()

# after much training, you can now generate sequences

generated = palm.generate(2048) # (1, 2048)

训练奖励模型:
import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()

# after much training

reward = reward_model(seq, prompt_mask = prompt_mask)
进行人工反馈强化学习:
import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer

# load your pretrained palm

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model

reward_model = RewardModel(
    palm,
    num_binned_output = 5
).cuda()

reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one

answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)

另外一个著名的轻量级分布式RLHF框架OpenRLHF,代码地址:https://github.com/OpenRLHF/OpenRLHF
采用OpenRLHF著名公司有:百度、腾讯、阿里巴巴、谷歌、中国电信等。

OpenRLHF用DeepSpeed实现分布式训练。
OpenRLHF用FlashAttention技术。FlashAttention是一种优化Transformer模型中注意力机制的技术,旨在提高计算效率并减少内存使用。‌ 它通过重新排序注意力计算,无需任何近似即可加速注意力计算并减少内存占用,特别适合用于大型语言模型的加速‌。
FlashAttention的核心原理是通过将输入分块并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作。它利用底层硬件的内存层次知识,如GPU的内存层次结构,通过将输入块从HBM加载到SRAM(快速缓存)上执行注意力操作,并将结果更新回HBM,从而减少了内存读写量,实现了2-4倍的时钟时间加速‌。
FlashAttention已经在多个知名大型语言模型(LLM)中应用,包括GPT-3、Falcon2、Llama2、Megatron-LM和GPT-4等。最新的FlashAttention-2版本进一步优化了算法,使用了更好的并行化和工作分区方法,提高了计算速度并支持更高的头维数和多查询注意力等新特性。

Python环境用python3.10, torch2.4
Python | flash_attn 模块安装指南:
https://www.iotword.com/23757.html
https://github.com/Dao-AILab/flash-attention/releases页面下载:flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
$ pip install /home/heyiqiu/下载/flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

OpenRLHF代码中默认是用Meta-Llama-3-8B预训练模型的,可以根据需要换成自己的模型及训练数据。
OpenRLHF中的后训练也分三步:监督微调、训练奖励模型、近端策略优化,在/OpenRLHF-main/examples/scripts中提供了这三步训练.sh脚本:train_sft_llama.sh
、train_rm_llama.sh、train_ppo_llama.sh。.sh脚本中超参数根据需要调整。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值