Orca-Math:解锁小学生数学问题解答的潜能

人工智能咨询培训老师叶梓 转载标明出处

数学文字题解答被认为是SLMs面临的一个难题。以往的研究假设,要使模型在GSM8K基准测试上达到超过80%的准确率,模型参数至少需要34亿。近期,微软研究院的研究团队提出了一种新的方法——Orca-Math,旨在显著提升SLMs在解答小学生数学问题上的能力。Orca-Math通过创新的方法,使得一个仅有7亿参数的SLM达到了86.81%的准确率,这一成就无需复杂的集成学习、代码执行或任何其他外部工具。

不同模型在GSM8K基准测试上的性能结果,直观比较Orca-Math模型与其他模型相比的性能

Orca-Math的核心方法

Orca-Math基于Mistral-7B,一个7亿参数的SLM。其核心方法包括:

高质量合成数据集:使用200K个合成数学问题,这些问题是通过多代理设置创建的,代理之间协作生成数据。

迭代学习技术:允许SLM通过解决数学问题、接收解决方案的反馈,并从偏好对中学习,这些偏好对结合了SLM的解决方案和反馈。

数据集构建

数据集的构建目标是创造一个包含简单和复杂问题的多样化小学数学文字题集合。研究团队首先从现有的开源数据集中收集了小学数学文字题样本。这些数据集包括 NumGLUE、AddSub、ALGES、ASDiv、DRAW、GSM8k、MATHQA、MultiArith、SingeOP、SingleEQ 和 SVAMP,共计 36,217 个问题。他们使用 Lila 基准来收集这些数据集,并特别从 Lila 的训练和验证分割中提取问题以构建种子集。

为了扩展种子集,团队创建了一个名为 “问我任何问题” 的代理,用于从种子集中的每个问题生成多个文字题。这个过程是通过以下步骤完成的:

  1. 将原始问题转换成陈述句。
  2. 对陈述句中的每个数字,创建一个新的文字题。

例如,给定问题:“Natalia 在四月向她的 48 个朋友卖了夹子,然后在五月卖了是四月一半数量的夹子。Natalia 在四月和五月总共卖了多少夹子?”答案是 72 个夹子。将问题转换成陈述句后,围绕数字 48 创建了新的问题:“Natalia 在四月向一些朋友卖了夹子,然后在五月卖了和四月一样多一半的夹子。Natalia 在四月和五月总共卖了 72 个夹子。她在四月卖了多少夹子?”以此类推,生成多个类似问题。

为了进一步扩展种子集并增加问题的难度,研究团队引入了两个新的代理:“建议者”和“编辑”。这两个代理协同工作,通过迭代过程来修改现有问题,使其变得更具挑战性。工作流程如:

  1. “建议者”检查特定问题并提出增加复杂性的几种方法,但不会实际创建问题。
  2. “编辑”根据“建议者”的建议,对原始问题进行修改,生成一个更新的、更具挑战性的问题。

例如,原始问题是关于 Joanne 每小时从购物中心的喷泉中收集硬币。通过迭代过程,问题变得越来越复杂,涉及到更多的喷泉、更多的变量和更复杂的交易。

这个过程可以进行多轮迭代,每轮都使问题的复杂性增加。使用 AutoGen 框架来实现多代理工作流程。最终,通过这个过程,团队收集了 37,157 个问题。

除了上述过程,研究团队还包括了来自 DMath 数据集的 6,216 个问题。这些问题是 DMath 训练集中的一个子集,其中 GPT4-Turbo 计算出的解决方案与精确的标准答案一致。

训练

研究团队对 Mistral-7B 模型进行了微调,使用的是 Orca-Math-200K 数据集。数据以一种指令格式呈现,即用户提出问题,助手给出答案。这里的损失仅计算在答案令牌上,使用了恒定的学习率 1×10−61×10−6。每个设备的批量大小被设置为 3,训练在一个包含八个 A100 节点的集群上进行,每个节点有八个 GPU,训练周期为一个 epoch。

为了为每个问题生成额外的正面和负面解决方案,团队从迭代 #1 中微调后的模型中采样了四个响应。具体来说,他们使用 top_p = 0.95 和temperature = 0.7。这一过程产生了一个数据集,其中每个问题都有 GPT4-Turbo 生成的一个解决方案和四个学生生成的解决方案。然后,使用 GPT4-Based-Exact-Match 提出的提示(见第 4 节详情)来评估教师(GPT4-Turbo)的答案与学生答案之间的一致性。对于所有学生生成的答案与教师答案不一致的解决方案,将其标记为负面;否则,将解决方案标记为正面。接着他们按照以下方式构建偏好数据集:

  • 对于每个问题 ,构建 ,即 的所有正面解决方案的集合。教师的解决方案被视为正面,因此这个集合至少包含一个元素。

  • 对于每个问题,还构建 ​,即的所有负 面解决方案的集合。如果所有四个响应都与教师的解决方案一致,那么这个集合可能是空的。实际上,大约有 80k 个问题的情况就是这样。对于这些情况,他们从中随机抽取一个响应,用于 4 个不同的 ​,其中。注意,在这种特殊情况下

  • 作为围绕的偏好数据集。最终的偏好数据集是通过取所有 在训练数据集中 的并集来创建的。

设M2 表示在迭代 #2 构建的数据集上使用 KTO [10] 训练的模型。他们复制了迭代 #2 的数据集构建过程;但是,使用M2 生成四个响应,而不是迭代 #1 中微调的模型。

为了从正面和负面反馈中学习,团队评估了两种算法的性能:直接偏好优化(DPO)和 Kahneman-Tversky 优化(KTO)。DPO 是一种简单且流行的方法,用于有效微调语言模型以符合偏好。此外,他们还探索了 KTO 的能力,KTO 的特点是只需要一个二元的“是”或“否”响应来评估输出的质量。

评估

Orca-Math 项目中的评估方法是使用精确匹配(Exact Match)作为指标。具体来说,给定模型生成的答案,研究者使用 GPT4 提取最终简短的答案,并将其与标准答案进行匹配,这种方法称为基于 GPT4 的精确匹配(GPT4-based-Exact-Match)。评估过程中使用的提示模板如下:

SYSTEM 作为一名专业的数学老师,你的任务是评估一个学生对文字题的答案。这个问题附有问题制定者提供的正确解决方案。重要的是要记住,解决文字题可能有多种方法,所以学生的解题步骤可能不总是与问题制定者的答案一致。然而,最终答案,通常是一个数字,应该是唯一的,并且应该与问题制定者的答案相匹配。你的任务包括分析学生的解决方案,识别任何错误,并确定是否可以修改答案以纠正错误。如果学生的答案无法修正,请考虑创建练习题以帮助他们提高理解。 使用以下格式: 错误分析:用一句话,从问题制定者的答案中提取最终答案,并与学生的答案进行比较。它们是否匹配? 最终裁决:正确/不正确

例如,对于 Billy 帮助人们报税的问题,问题制定者和学生的答案都是 Billy 总共帮助了 240 人。助理(评估者)的错误分析确认了学生的答案与问题制定者的答案相匹配,最终裁决是正确的。

表 2 展示了在包含 1319 个文字题的 GSM8k 测试集上,不同训练过程的性能。Mistral-7B 模型经过最多三次迭代的微调。在第一次迭代中,使用监督式微调获得 M1。在第二次迭代中,比较了 SFT、DPO 和 KTO,其中 KTO 训练的模型在这组中表现更好,称之为 M2,并用 M2 生成第三次迭代的数据集。在第三次迭代中,比较了以 M2 为起点的 DPO 和 KTO。还将这些与 Orca-Math-200K 数据集上的三次 SFT 训练进行了比较。

对于所有 SFT 训练,采用恒定的学习率 1×10−61×10−6,每设备批量大小设为 3,训练周期设为 1。对于 DPO 和 KTO 训练任务,设置 beta 为 0.3,每设备批量大小为 3,梯度累积步数为 11,训练周期为 1。在第二次迭代中,DPO 和 KTO 训练采用恒定学习率 1×10−61×10−6,在第三次迭代中采用 1×10−71×10−7。

消融研究部分探讨了模型生成的正面样本的影响,通过限制只包含教师生成的解决方案,即在创建第二次迭代的数据集时,移除了模型生成的任何 ​。表 3 展示了用这个数据集训练 M1 一周期的 DPO 和 KTO 的结果。不论训练算法如何,性能都有显著下降。

表 5 展示了 Orca-Math 在其他几个文字题数据集上的性能。为了便于评估,选择的基准数据集中每个问题的答案都是单个数字。基准的测试集是从 Lila 获取的。采用基于 GPT4 的精确匹配指标,并且模型响应是使用贪婪解码生成的。

尽管在训练或作为合成问题生成的种子时从未使用过 GSM8K 或其他数据集的测试分割,但研究者采取了以下方法来检测潜在的文本污染:

首先对文本进行预处理,包括将所有字符转换为小写,去除标点符号,将文本分词成单独的单词,并去除常见的英文停用词,以确保数据的一致性。

然后使用词频-逆文档频率(TF-IDF)方法向量化文本语料库,并确定测试集和训练集之间的余弦相似度,从中为每个测试查询选择 top-k(k=10)最相似的问题。

最后,通过计算测试问题与其对应的训练集匹配项之间最高 n-gram 重叠次数,并使用 Jaccard 相似度来评估文本污染的程度。为了进行严格的污染检查,将 n 设置为 1。重要的是要注意,使用 Jaccard 相似度测量的 n-gram 重叠是 n 的非增函数。

执行算法后,发现有 8 个测试问题显示出显著的 n-gram 重叠,这表明根据定义的阈值,测试集中的文本污染可以忽略不计。当将训练集限制为仅包含种子问题时,显示出显著 n-gram 重叠的测试问题数量为 7。注意,对于 n ≥ 2,显示出显著 n-gram 重叠的测试问题数量为零。

Orca-Math的研究为SLMs在数学推理能力上的显著提升提供了有力证据。通过迭代学习技术和利用正面及负面信号,Orca-Math成功超越了之前认为的80%的准确率障碍。这一成果不仅凸显了SLM性能的潜在重大改进,也强调了创新学习策略和数据集生成方法在推动SLMs发展中的重要性。

论文链接:https://arxiv.org/abs/2402.14830

  • 8
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

人工智能大模型讲师培训咨询叶梓

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值