LLM.02 Mixed Distillation Helps Smaller Language Model Better Reasoning

标题: Mixed Distillation Helps Smaller Language Model Better Reasoning


在这里插入图片描述
这篇论文介绍了一种名为“混合蒸馏”(Mixed Distillation, MD)的框架,旨在通过结合“思考过程链”(Chain of Thought, CoT)和“程序思考路径”(Program of Thought, PoT)两种能力,来提升小型语言模型在推理任务上的表现。论文指出,尽管大型语言模型(Large Language Models, LLMs)在自然语言处理(NLP)任务上取得了显著成就,但它们的高计算和内存需求限制了实际应用。因此,研究者探索如何将LLMs的知识迁移到小型模型中,以在保持性能的同时降低资源消耗。


⏲️ 年份: 2024
👀期刊/影响因子:
📚 数字对象唯一标识符DOl:
🤵 作者: Li Chenglin,Chen Qianglong,Li Liangyue,Wang Caiyu,Li Yicheng,Chen Zulong,Zhang Yin
论文链接:https://download.csdn.net/download/klhhk/89490144


👁️‍🗨️摘要:

主要贡献包括:

  1. 混合蒸馏框架:MD框架创新性地融合了CoT和PoT两种提示技术。CoT鼓励LLMs生成中间推理步骤的自然语言描述,而PoT则促使LLMs生成可执行的Python代码作为中间步骤,从而增强模型的逻辑推理能力。这些能力随后被蒸馏到小型模型中,使得小模型能够进行更复杂的多路径推理。
  2. 实验验证:研究通过一系列实验展示了MD的有效性,不仅提升了基于LLaMA2-7B和CodeLlama-7B等小型模型在SVAMP基准测试上的准确率,分别达到了84.5%和85.5%,还超过了GPT-3.5-Turbo的表现。这证明了MD能显著增强模型的单路径和多路径推理能力,并在推理任务的准确性和泛化能力上超越了单独蒸馏模型的综合性能。
  3. 方法细节:MD框架首先利用精心设计的提示模板从无标签数据集中提取LLMs的推理路径,包括自然语言形式的CoT和Python代码形式的PoT。然后,过滤掉不可执行或无答案的路径,并将筛选后的混合思考路径用于训练小型任务特定模型。在训练过程中,同时利用CoT和PoT作为监督信号,帮助小型模型学习多类型的数据分布,进而提升其多路径推理能力。最后,在推断阶段,小型模型利用这两种能力通过自我一致性投票来确定最终答案。
  4. 与现有工作的对比:MD框架弥补了先前工作在探索PoT和CoT如何协同改进小型模型推理能力方面的空白,尤其是通过结合两种提示技术,实现了对小型模型性能的有效提升。而传统知识蒸馏方法通常仅依赖于CoT作为监督信号,忽略了PoT的潜力。

综上所述,这篇论文通过引入混合蒸馏策略,为提升小型语言模型的推理能力提供了一个新的有效途径,对于推动NLP领域中模型轻量化和高效应用具有重要意义。

👀研究背景和研究问题:

研究背景

随着近年来大型语言模型(LLMs)在自然语言处理(NLP)领域的显著进步,它们展现出了强大的性能,特别是在解释预测、生成中间推理步骤方面的能力。然而,部署这些大型模型面临严峻挑战,主要是因为它们在实际应用中对计算资源和内存有着极高的需求。因此,如何在不牺牲性能的前提下缩小模型规模,成为了研究者关注的重点。近期研究致力于通过知识蒸馏技术,将大型模型的知识转移到小型模型中,以期在保持性能的同时降低资源消耗,但小型模型在需要推理能力的任务上仍难以匹敌大型模型的表现。

在此背景下,本文作者们观察到,尽管已有不少工作聚焦于通过单一路径(如链式思考CoT或程序思考PoT)的知识蒸馏来提升小模型性能,但这些方法往往未能充分利用大型模型在多路径推理上的潜力。因此,开发一种新的蒸馏框架,既能融合CoT和PoT的各自优势,又能有效提升小型模型的单路径及多路径推理能力,成为了一个亟待解决的问题。

研究问题

  1. 如何有效地结合CoT与PoT能力:研究者关注如何设计一种混合蒸馏(Mixed Distillation, MD)框架,将大型语言模型中的链式思考(Chain of Thought, CoT)和程序思考(Program of Thought, PoT)两种推理模式的能力综合起来,并成功地蒸馏到小型模型中。这要求探索如何在蒸馏过程中同时保留并融合这两种推理方式的特点。
  2. 提升小模型的推理能力:研究的核心问题是如何通过混合蒸馏框架显著增强小型语言模型在多种任务上的单路径和多路径推理能力,特别是那些需要复杂逻辑推理的任务,比如数学问题解答、常识推理等。
  3. 超越现有模型性能:研究旨在通过实验验证混合蒸馏框架下训练的小型模型(如LLaMA2-7B和CodeLlama-7B)在推理基准测试(如SVAMP)上的表现,能否超越传统的蒸馏方法和封闭源代码的大型模型(如GPT-4和GPT-3.5-Turbo),并且在准确率和泛化能力上取得显著提升。

综上所述,该论文的研究背景是基于大型语言模型的推理能力及其在实际应用中的局限性,而研究问题则集中于设计和验证一种新型混合蒸馏方法,以此提升小型模型的推理性能,并在一系列基准测试中展示其实效性。

🎨研究方法及改进:

研究方法

1. 思路提取(Thoughts Extraction)

首先,研究团队利用多轮提示技术从LLMs中提取思路。对于CoT,通过精心设计的自然语言提示,促使LLMs生成解决问题的中间推理步骤。而PoT则通过设计特定的Python程序执行相关的提示,激发模型生成可执行代码片段作为中间推理过程。这种方法不仅限于自然语言描述,也包括了形式化的编程逻辑,从而丰富了模型的推理路径。

2. 混合思路蒸馏(Mixed Thoughts Distillation)

接下来,将提取的CoT和PoT思路结合用于训练小型任务特定模型。通过将这些混合思路作为监督信号,模型学习如何模仿大型模型的推理过程。此过程不仅强化了模型的自然语言理解,还提高了其形式化逻辑处理能力。研究中采用了一种标准的特定任务学习范式,即通过最小化预测标签与目标标签之间的交叉熵损失来微调模型。

3. 自我一致性投票(Self-Consistent Voting)

在推理阶段,小型模型利用CoT和PoT的双重能力进行多路径推理,并通过自我一致性投票策略确定最终答案。这意味着模型会考虑多个推理路径的输出,并基于它们的一致性来决定最可能正确的答案,从而提高决策的准确性。

论文中提及的关键公式及相关概念如下:
  1. 损失函数(Loss Function)

    1. 基础损失定义

      基础的损失函数(记为 L L L)衡量了模型预测输出( y i ^ ​ \hat{y_i}​ yi^)与目标输出(可以是人工标注的 y i ​ y_i​ yi 或是大模型预测的 y i ^ ​ \hat{y_i}​ yi^)之间的不一致程度,常使用交叉熵损失( ℓ ℓ )来计算。对于一个包含 N N N 个样本的数据集,损失函数定义为: L = 1 N ​ ∑ i = 1 N ​ ℓ ( f ( x i ​ ) , y ​ i ^ ​ ) L=\frac{1}{N}​∑_{i=1}^N​ℓ(f(x_i​),\hat{y_​i}​) L=N1i=1N(f(xi),yi^) 其中, f ( x i ​ ) f(x_i​) f(xi) 表示模型对于输入 x i ​ x_i​ xi 的预测输出, y i ^ \hat{y_i} yi^ 是对应的期望输出, ℓ ℓ 是损失函数, N N N 是样本总数。

    2. 多任务学习损失

      在MD框架中,为了同时利用CoT和PoT的能力,损失函数被设计为两部分的加权和,分别对应于CoT推理路径和PoT推理路径的损失: $ L = (1 - \lambda)L_{\text{path_CoT}} + \lambda L_{\text{path_PoT}}$ 这里, λ λ λ 是一个权重参数,用于平衡CoT损失( L path_CoT ​ L_{\text{path\_CoT}}​ Lpath_CoT)和PoT损失( L path_PoT ​ L_{\text{path\_PoT}}​ Lpath_PoT)的重要性,默认设置为0.5,以实现两者的均衡考虑。

    3. 单一路径损失

      对于每个路径,定义了单独的损失 L p a t h ​ L_{path​} Lpath,它既包括生成推理路径的损失,也包括预测标签的损失: $ L_{\text{path}} = \frac{1}{N} \sum_{i=1}^{N} \ell(f(x_i), \hat{r}_i + \hat{y}_i)$ 其中, r i ^ ​ \hat{r_i}​ ri^ 是由大语言模型生成的推理路径(可以是自然语言形式的CoT路径或代码形式的PoT路径), y i ^ ​ \hat{y_i}​ yi^ 是与之对应的预测标签。这意味着模型不仅要学习预测正确答案,还要学会生成合理的推理路径。

    4. 输入与任务提示

      在实际操作中,输入 x i ​ x_i​ xi 被嵌入到特定的任务提示(如 “Let’s think step by step” 或 “Let’s break down the code step by step”)中,以引导模型生成特定类型的推理路径。这些提示被用作生成CoT和PoT路径的引导,增强了模型在特定任务上的推理能力。

    5. 推理与投票机制

      在推理阶段,输入 x i ​ x_i​ xi 会与相应的提示语句结合,分别生成CoT和PoT的推理路径,然后通过多次采样生成答案列表 A C o T ​ A_{CoT}​ ACoT A P o T ​ A_{PoT}​ APoT。最终预测 P f i n a l ​ P_{final}​ Pfinal 通过一个投票函数 V V V 从这两个列表的合并中选出,该函数通常会选择出现频率最高的答案作为最终输出。

      通过这样的损失函数设计,MD框架不仅能够提升小模型在单一任务上的推理能力,还能促进模型在面对需要多路径推理的任务时展现出更强的通用性和准确性。

  2. 最终预测(Final Prediction)

最终预测的生成,即:

P f i n a l ​ = V ( c o n c a t ( A C o T ​ , A P o T ​ ) ) P_{final}​=V(concat(A_{CoT}​,A_{PoT}​)) Pfinal=V(concat(ACoT,APoT))

这个公式概括了混合蒸馏框架中推理过程的最后一步,即如何综合链式思考(CoT)和程序思考(PoT)的输出来得到最终答案。下面是对这个公式含义的详细解读和其背后逻辑的推导说明:

公式解释

  • P f i n a l ​ P_{final​} Pfinal: 表示最终预测的结果。
  • V ( ⋅ ) V(⋅) V(): 是一个投票函数,作用是从给定的选项列表中选择出现频率最高的项。这意味着它负责从合并的CoT和PoT答案集中确定哪个答案最可能是正确的。
  • c o n c a t ( A C o T , A P o T ) concat(A_{CoT},A_{PoT}) concat(ACoT,APoT): 表示将两个答案列表连接起来形成一个新的列表。其中, A C o T = a 1 , a 2 , . . . , a n A_{CoT}={a1,a2,...,an} ACoT=a1,a2,...,an代表通过链式思考路径获得的答案列表,而 A P o T ​ = b 1 ​ , b 2 ​ , . . . , b n ​ A_{PoT}​={b1​,b2​,...,bn​} APoT=b1​,b2​,...,bn代表通过程序思考路径获得的答案列表。每个列表中的元素是通过多次采样得到的潜在正确答案。

推导逻辑

  1. 思想路径生成: 首先,通过不同的提示策略(“Let’s think step by step” 用于CoT路径,“Let’s break down the code step by step” 用于PoT路径)激发大型语言模型(LLM)生成推理路径。这些路径被用来指导小型模型学习如何逐步推理问题。
  2. 答案采样: 对于给定的输入问题,小型模型在推理时,首先使用CoT提示产生一系列自然语言的推理步骤,然后执行这些步骤以生成答案集合 A C o T ​ A_{CoT}​ ACoT。类似地,通过PoT提示,模型生成可执行的代码片段,利用Python执行器运行这些代码来获取另一组答案集合 A P o T ​ A_{PoT}​ APoT。每个集合中的答案是通过多次独立的采样迭代得到的。
  3. 答案合并与投票: 最后,将这两个独立推理路径得到的答案集合合并成一个大的列表,然后应用投票函数 V V V来确定最终答案。投票函数的工作原理是统计每个答案出现的次数,并选择出现频率最高的答案作为模型的最终预测。这样可以结合两种思考方式的优点,提高预测的准确性和可靠性。

改进之处

  • 多路径推理能力的提升:与仅依赖单一路径(CoT或PoT)蒸馏的先前工作相比,MD框架强调了结合两种路径的优势,显著增强了模型在不同任务上的单路径和多路径推理能力。
  • 兼顾自然语言与编程逻辑:通过同时蒸馏CoT和PoT,模型不仅能够理解和生成自然语言的推理步骤,还能生成可执行代码,提升了在需要数学和逻辑推理任务上的表现。
  • 更广泛的适用性和准确性:实验结果显示,使用MD训练的LLaMA2-7B和CodeLlama-7B模型在SVAMP基准测试上分别实现了84.5%和85.5%的准确率,优于GPT-3.5-Turbo的性能。这表明MD框架在提升模型的推理准确性和泛化能力方面取得了实质性的进展。
  • 解决过分布数据问题:MD框架还被证明在处理超出训练数据分布的任务时,通过引入多路径推理,能有效提升模型的学习能力和适应性,进一步巩固了其在推理任务上的有效性。

‼️实验对比结果:

单一路径推理能力提升
  • PoT蒸馏:表与传统的CoT蒸馏和标签微调相比,采用PoT蒸馏的模型在多个任务上表现更优。例如,T5-large在SVAMP上的准确率提高了61.2%,LLaMA2在GSM8K和ASDIV上的准确率分别提高了33.2%和14.9%。这证明PoT作为监督信号在特定任务上优于CoT。
混合蒸馏效果
  • 多路径推理能力增强:MD框架不仅提升了单路径推理能力,还在多路径推理上显示了显著优势。特别是当使用MD框架时,LLaMA2-7B和CodeLlama-7B在SVAMP上的准确率分别达到了84.5%和85.5%,超过了GPT-3.5-Turbo的性能,分别高出2.5%和3.5%。此外,表中还展示了MD框架在其他任务上的提升,如CodeLlama-7B在ASDIV上达到的准确率为53.2%,相比未使用MD提高了19.2%。
  • 自我一致性投票机制:通过设置默认的总采样路径数为20(Wang et al., 2022),MD框架确保了最终预测结果是基于每个独立路径获得答案的投票机制决定的,增强了结果的稳定性和准确性。
综合性能对比
  • 与基线模型的比较:MD框架下的模型在各项指标上均超越了诸如单路径CoT蒸馏、PoT蒸馏及未使用MD方法的模型。例如,T5-Large-MD在采用CoT w/ PoT(即结合CoT和PoT的混合蒸馏)后,准确率从基础的54.3%提升至76.0%,增幅显著。
  • 泛化能力验证:研究还验证了MD框架的泛化性,表明它不仅在特定任务上表现优异,而且在跨任务上也能维持高水平的推理能力。

📚数据集以及评价指标:

数据集

  1. SVAMP (Patel et al., 2021): 专注于数学推理任务,特别侧重于解决代数方程和算术问题,要求模型具备较强的数学逻辑推理能力。
  2. GSM8K (Cobbe et al., 2021): 包含了大量的数学应用题,要求模型能够理解自然语言问题并执行多步骤推理来找到答案,是检验模型数学推理能力的重要基准。
  3. ASDIV (Miao et al., 2021): 专门设计用于评估模型在处理算术、几何、统计和数据分析等多类型数学问题上的能力,是一个综合性的数学推理数据集。
  4. StrategyQA (Geva et al., 2021): 该数据集用于评估模型的常识推理能力,问题涉及需要隐含推理策略的情境,要求模型具备逻辑推理和世界知识理解。

评价指标

评价指标主要集中在准确率(Accuracy),这是衡量模型预测结果与真实答案匹配程度的最直观指标。论文中通过比较不同模型在各个数据集上的准确率变化来评估MD框架的效果。例如,表1和表2展示了在SVAMP、GSM8K、ASDIV等数据集上,采用MD框架的模型相比于仅使用CoT或PoT蒸馏的模型,以及未经过MD处理的模型,在准确率上有显著提升。具体来说,如LLaMA2-7B和CodeLlama-7B模型在使用MD后,在SVAMP上的准确率分别达到了84.5%和85.5%,相较于仅使用CoT或PoT蒸馏的方法,准确率提高了46.5%和46.5%,甚至超越了GPT-3.5-Turbo的表现。


💯论文创新点:

  1. 混合蒸馏框架(Mixed Distillation Framework):提出了一个新颖的蒸馏方法,首次将链式思考(CoT)和程序思考(PoT)两种不同的推理路径相结合,通过一种综合性的蒸馏过程,有效提升了小型语言模型的推理能力。这一框架打破了以往仅依赖单一推理路径进行知识转移的限制,创新性地融合了自然语言推理和形式化程序逻辑的优势。
  2. 程序思考(Program of Thought, PoT)引入:在蒸馏过程中创新性地引入了程序思考的概念,不仅限于自然语言的链式思考路径,还通过生成可执行代码片段作为推理路径,拓宽了模型学习和推理的范围。PoT的引入使得模型能够处理更广泛的逻辑和数学问题,增强了模型的计算和逻辑推理能力。
  3. 多路径推理与自我一致性投票:通过结合CoT和PoT的输出,论文提出了多路径推理策略,并采用自我一致性投票机制来确定最终答案。这种机制提高了模型在复杂问题上的鲁棒性和准确性,是传统单路径推理的一个重要创新升级。
  4. 高效模型微调与泛化能力提升:论文中展示了如何利用QLORA方法在有限资源下(如单个GPU)高效地微调模型,同时保持了模型的泛化能力。这不仅降低了资源门槛,也证明了混合蒸馏框架在实际应用中的可行性与有效性。
  5. 超越封闭源模型的性能:实验结果表明,经过混合蒸馏训练的小型模型(如LLaMA2-7B和CodeLlama-7B)在多个基准测试(如SVAMP)上的表现,超越了一些封闭源的大型模型(如GPT-3.5-Turbo)。这不仅是技术上的突破,也为开放研究社区提供了与封闭源模型竞争的新途径。
  6. 理论与实践的紧密结合:论文不仅在理论上构建了混合蒸馏的框架,还在实践中进行了详尽的实验验证,包括不同数据集上的性能对比、训练数据量影响分析等,确保了理论创新的实际价值和可应用性。

❓启发与思考:

🍞不足及可改进的点:

  • 17
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

KeSprite

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值