【Preference Learning】Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning

1、问题背景

现有的使用偏好数据的方式有两种:一种是使用基于偏好数据构建奖励模型,另一种是直接使用偏好数据更新模型。传统的RLHF方法中,奖励模型是静态的、离线的。新提出的一种方式是“迭代”的使用偏好数据直接更新模型,它涉及一个循环的过程,会从当前策略开始,通过收集和分析数据来生成新的偏好数据,再使用新生成的数据来更新策略。这种方式更专注于模型的持续适应性,让模型更适用于人类决策和推理的复杂性。

AlphaZero结合神经网络、强化学习技术和蒙特卡洛搜索树取得了惊艳的效果。通过将MCTS作为策略来改善策略算子,提升LLM成为了一种可能。但在实际是线上将会面临两个挑战:一个挑战是需要确定MCTS的使用粒度。通常,偏好数据是在实例级别收集的,实例级的方法采用稀疏监督,这可能会丢失重要信息,并且可能无法改进MCTS的潜力。另一个挑战是MCTS对评论家或者奖励函数的依赖,这对MCTS生成的拓展信息提供有意义的反馈至关重要,指导其策略的改进。

2、本文方法

  1. 针对MCTS粒度问题:本文将实例级别的偏好数据分解成使用MCTS的步骤级别的数据来指导模型。
  2. 针对MCTS对评论家或者奖励函数的依赖:本文采用自我评估的方式来评估模型的输出情况对偏好数据进行标注,充当评论家或奖励模型来指导模型进行改进。这种自我评估的好处是简化了流程,提高中间步骤的一致性,确保更新迭代的模型鲁棒性更强和LLM更一致。
  3. 训练过程:动态生成文本,再通过基于MCTS的自我评估的反馈对偏好进行标注。然后,使用DPO算法利用偏好数据对LLM进行更新。

本文提出的方法在各种算数、常识性推理任务上进行广泛评估,观察到显著的性能改进。

在这里插入图片描述

本文方法从初始策略 π θ ( 0 ) \pi_{\theta(0)} πθ(0)和 prompt数据集 D p D_p Dp开始,每次迭代都设计从 D p D_p Dp中选择一批提示,模型在当前策略 π θ ( i − 1 ) \pi_{\theta(i-1)} πθ(i1)的指导下,为每个提示生成潜在响应。然后,我们会应用一组动态进化的奖励标准,从这些响应中提取偏好数据 D i D_i Di。模型的策略随后会使用这些偏好数据进行调优,对策略 π θ ( i ) \pi_{\theta(i)} πθ(i)进行更新。循环的进行采样、响应生成、偏好数据抽取和策略更新。从而持续的让模型进行自我改进并与不断更新的偏好数据保持一致。
在这里插入图片描述

文中使用MCTS作为近端策略改进算子,将当前的策略转换为改进策略。使用MCTS迭代地收集偏好数据,利用其前瞻性能力将实例级奖励分解为更细粒度的步骤级信号。为了增强中间步骤的一致性,文中结合了自我评估机制,不断的更新新生成数据的质量评估。在每轮迭代采样偏好数据期间,本文的方法可以使用MCTS去平衡数据质量开发和数据多样性的探索。该方法可以被视为DPO的在线版本,其中更新的策略被迭代地用于通过MCTS收集偏好数据。因此,该方法不仅能解决偏好数据收集和策略更新方面的问题,还引入了一个动态、迭代的框架,显著增强了LLM的推理能力。

(1)MCTS构建步骤级偏好数据
MCTS的过程从一个根节点开始,在三个迭代阶段展开,分别为:选择、扩展和备份。

  • 选择
    这一阶段的目标是确定能够平衡搜索质量和计算效率的节点。选择阶段由两个关键变量指导: Q ( s t , a ) Q(s_t,a) Q(st,a) N ( s t ) N(s_t) N(st)

其中, Q ( s t , a ) Q(s_t, a) Q(st,a)是在状态 s t s_t st下去做动作 a a a后可获得的价值。 N ( s t ) N(s_t) N(st)是访问状态 s t s_t st的访问频率。为了在探索新节点和利用已访问节点之间进行权衡,文中采用预测器+置信度上限的方式进行权衡。在节点 s t s_t st处,遵循下述公式:

在这里插入图片描述

其中, p ( a ∣ s t ) = π θ ( a ∣ x , s t ) / ∣ a ∣ λ p(a|s_t)=\pi_\theta(a|x,s_t)/|a|^\lambda p(ast)=πθ(ax,st)/∣aλ表示生成步骤a的策略 π θ \pi_\theta πθ的概率分布,通过 λ \lambda λ权重来惩罚长度,以防止推理链路过长。

  • 扩展
    在选择过程中,扩展发生在叶结点上,以整合新节点并评估奖励。在状态 s t s_t st下执行动作a获得的奖励值 r ( s t , a ) r(s_t,a) r(st,a)通过计算 R ( s t ) R(s_t) R(st) R ( s t + 1 ) R(s_{t+1}) R(st+1)的奖励值之差求得,可突出动作在状态 s t s_t st下的优势。如下述公式2所示,奖励计算融合了正确性 O O O和 自我评价 C C C这两个值,输出的结果中正确则为1,不正确则为-1,未完成的中间状态则为0。自我评价在公式3所示,其中A表示正确的选项在token级概率中的置信度得分。

在这里插入图片描述

通过选择和扩展的过程,直至达到最终状态(生成响应完成或者达到预设的最大高度)。

  • 备份
    到达终端状态后,从终端节点到根节点自下而上的更新,会更新访问次数N、状态值V和转换值Q:
    在这里插入图片描述

其中, γ \gamma γ是对未来状态的折现因子。

在响应生成的每一步,我们进行了K次MCTS迭代构建搜索树,同时更新Q值和访问次数。为了平衡构建树时的多样性、质量和效率。文中的搜索宽度会逐渐变窄,搜索宽度初始化为b1之后逐渐退火为更小的b2。文中使用对应于每个候选步骤的结构Q值来标记偏好,其中Q值大的会作为下一步的首选。对于深度为T的结果搜索树,我们会得到T对步骤级偏好数据。具体来说,在每层中,我们会选择最大Q值作为正例,最小Q值作为负例。在树的每层深度中选择父节点是最大值,这个值是访问次数与子节点访问次数的范围相乘,最大值会作为父节点。这种方式可表示出一代代中的质量和多样性。

(2)迭代更新学习
基于MCTS收集的步骤级偏好数据,使用DPO方式进行微调。考虑到Q值决定的偏好标签中的噪声,采用普通的DPO并使用MCTS中模拟的访问计数对每个偏好数据对应用适应标签平滑。使用简写 h π θ y w , y l = l o g π θ ( y w ∣ x ) π r e f ( y w ∣ x ) − l o g π θ ( y l ∣ x ) π r e f ( y l ∣ x ) h_{\pi_{\theta}}^{y_w,y_l}=log\frac{\pi_{\theta}(y_w|x)}{\pi_{ref(y_w|x)}}-log\frac{\pi_{\theta}(y_l|x)}{\pi_{ref(y_l|x)}} hπθyw,yl=logπref(ywx)πθ(ywx)logπref(ylx)πθ(ylx),在第i次迭代时,给定最新策略 π θ ( i − 1 ) \pi_{\theta}(i-1) πθ(i1)抽样的一批偏好数据 D i D_i Di,我们将策略目标 l i ( θ ) l_i(\theta) li(θ)表示为:

在这里插入图片描述

其中, y w y_w yw y l y_l yl分别代表偏好数据和非偏好数据,超参数 β \beta β表示KL约束。 α x , y w , y l \alpha_{x,y_w,y_l} αx,yw,yl是标签平滑变量,这个变量是由搜索树中偏好数据 y w , y l y_w, y_l yw,yl在相应状态下的访问次数计算得到:
在这里插入图片描述

其中, N ( x , y w ) N(x,y_w) N(x,yw) N ( x , y l ) N(x,y_l) N(x,yl)分别表示根据输入x做出的相应动作 y w y_w yw y l y_l yl的状态。

3、实验设置

(1)基座模型:Mistral-7B

(2)数据集:算术推理(GSM8K和MATH)、常识推理(ARC、AI2Science、OpenBook和CommonSenseQA)

在GSM8K数据集中评估了COT和POT推理能力,整合了GSM8K和MATH的训练数据构建偏好学习框架的prompt数据。

(3)baseline:

  • Self-Taught Reasoner:基于实例级别原理生成的迭代学习模型
  • Crystal:一种专注于常识推理中知识内省的强化学习微调模型
  • 直接微调:不使用思维链对基座模型进行微调
  • Language Model Self-Improvement:一种使用自一致性收集正例数据的自我训练方法
  • Math-Shepherd:将过程监督继承到PPO中

4、实验效果

(1)算数推理任务
本文中方法在GSM8K数据集上从75.9%增加到了81.8%,在MATH数据集上从28.9%增加到了34.7%。与同样在偏好学习中使用过程监督的Math-Shepherd相比,本文方法在不需要训练单独的奖励模型的情况下实现了类似的性能增益。

在这里插入图片描述

(2)常识推理任务
论文中的方法在ARC-Challenge (ARC-C)、AI2Sci-Middle (AI2Sci-M)和SciQ上分别实现了2.5%、3.0%和2.1%的绝对精度提高,超过了直接调优的结果。在OBQA和CSQA等任务中,论文中的方法侧重于中间推理的细化,与直接调优相比效率较低。尽管在监督微调(SFT)基线上有了显著的改进(例如,在OBQA上从59.8%提高到79.2%,在CSQA上从54.1%提高到74.8%),但与直接调优相比,这些改进并不大。这种差异可能归因于基础模型缺乏具体知识,其中引出中间推理链可能会在模型代中引入更多的不确定性,从而导致错误的预测。

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

辰阳星宇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值