通过知识蒸馏提升大模型训练效率

 人工智能咨询培训老师叶梓 转载标明出处

随着模型规模的不断扩大,如GPT-4这样的模型拥有约1.7万亿参数,其预训练所需的巨大能源和计算资源引发了对可持续发展AI解决方案的迫切需求。麦吉尔大学的研究团队介绍了一种创新的方法来解决与LLMs预训练相关的效率问题,即通过知识蒸馏实现跨架构的知识转移。研究团队提出了一种名为Hyena的机制,该机制通过替代变换器模型中的注意力头,提供了一种成本效益更高的替代传统预训练的方法。与传统的压缩方法不同,该技术不仅提高了推理速度,而且在准确性和效率方面都超越了预训练。

方法

Hyena算子是本文的核心创新之一,由Poli等人在2023年提出。它旨在作为次线性(subquadratic)替代方案,以替换变换器中的注意力(attention)操作。与H3等其他状态空间模型不同,Hyena直接对滤波器进行参数化,这相当于线性时不变(LTI)系统的脉冲响应。

具体来说,Hyena算子首先对时间索引应用位置嵌入,其中df​是嵌入维度。然后,通过前馈神经网络(FFN):,其中dm​是模型的维度,并将结果乘以一个窗口函数以获得滤波器h[n]。数学表达式为:

Hyena算子​使用这样的滤波器ℎh来聚合长上下文窗口的上下文,并通过对乘法门控机制引入非线性。首先通过投影操作P(x,θ)获得三个投影q,k,v,该操作由参数θ控制。投影操作包括一个线性投影​,然后是一个短的深度卷积,使用短滤波器​进行局部信息交换。然后使用逐元素乘法,接着是卷积和第二个逐元素乘法来计算Hyena算子的输出:其中∗表示卷积操作,⊙表示逐元素乘法。注意,通过使用不同数量的投影,可以进一步泛化该算子。

在进行实验时,研究团队选择了70M参数版本的GPT-NeoX模型,这是一个仅解码器的变换器模型,其架构与GPT-3非常相似,但存在一些关键差异:

  • 传统GPT模型中的位置嵌入被旋转位置嵌入(RoPE)所替代,它使用旋转矩阵对token的位置信息进行编码。
  • 通常在传统GPT模型中串行发现的注意力和前馈层在GPT-NeoX中为了效率而并行计算。
  • 所有的前馈层都是密集的,与GPT-3中密集和稀疏层的交替不同。

值得注意的是,GPT-NeoX的架构与GPT-J非常相似。图1展示了模型架构的详细图示,其中包括:

  • A) GPT NEO X层架构:70M GPT NEO X中的6层堆叠注意力和多层感知机(MLPs)。
  • B) 使用Hyena算子替换注意力头的Hyena-Distilled NEO GPT X层架构,用于蒸馏任务。
  • C) 来自Vaswani等人(2017)的注意力算子的视觉表示。
  • D) 来自Poli等人(2023)的Hyena算子的视觉表示。

本文的目标是将注意力机制替换为Hyena机制。由于Hyena算子已经保留了其输入token的位置信息,因此Hyena版本的模型不包括旋转位置嵌入。研究使用了Biderman等人在2023年实现的Pythia模型,并在开源的Pile数据集上进行了训练。

研究采用了逐步知识转移(Progressive Knowledge Transfer)的方法来逐步训练学生模型。对于每一层,首先在教师模型上对一个token数据集X进行推理,以获得一个蒸馏数据集,其中x是token索引序列,​是教师模型在第i层的输出。然后,最小化均方误差损失,使用​——学生模型在第i层的输出,一次训练一层。对于最后一层,可以通过在文本数据上进行无监督训练来额外微调模型:

所有语言建模实验都使用了OpenWebText数据集。通过从OpenWebText中随机抽取200万个示例来获得一个标记化的预训练数据集,每个预训练示例的上下文长度为1024。数据集被分为训练集和验证集,其中0.1%被保留用于验证。对于蒸馏实验,从训练集中采样了4000万个token来获得用于训练每层的蒸馏数据集。

所有实验都使用了与70M教师模型相同的6层GPTNeoX风格架构。研究者首先基于Pythia和Hyena模型的超参数,从头开始对模型进行预训练,使用了10亿个token。研究者定义预训练为从随机初始化的模型开始,在文本数据上进行无监督学习的过程。同样,研究者定义无监督微调(CE-tinune)为从模型检查点开始,在文本数据上进行无监督学习的过程。在预训练阶段,研究者实现了一个线性预热,跨越300个训练步骤,然后使用余弦衰减在2000次迭代中降低学习率。这种衰减持续到达到最大学习率的10%,此时学习率保持不变。类似地,在蒸馏过程中,研究者在总训练步骤的2.5%上实施线性预热,然后在整个步骤集上衰减,直到达到最大学习率的10%。研究者尝试只进行蒸馏(MSE)以及微调(CE-tinune)。所有实验都设计在RTX 3090上运行5小时。

在Pythia模型的解码器层上进行渐进式知识转移的图示

结果与分析

困惑度(Perplexity)作为衡量语言模型性能的关键指标,用于评估模型对真实数据分布的预测准确性。研究者使用了OpenWebText和WikiText数据集来计算所有模型的困惑度得分。他们采用了与预训练数据集相同的验证集来计算得分,并且所有模型的困惑度得分都是在1024个token的上下文长度下获得的。

表1展示了四种不同模型的困惑度得分:

  • PYTHIA-70M (TEACHER): 教师模型,使用传统的注意力机制,其在WikiText和OpenWebText上的困惑度得分分别为51.4和35.3。
  • PRE-TRAINED: 直接预训练的Hyena模型,得分较高,分别为230和64.9。
  • MSE: 使用均方误差(MSE)损失进行蒸馏后的Hyena学生模型,得分有所下降,分别为155.8和63.5。
  • CE FINE-TUNE: 在蒸馏后进行交叉熵(CE)微调的Hyena学生模型,其困惑度得分进一步降低,分别为121.2和49.6。

这些结果表明,经过蒸馏和微调的学生模型在语言建模任务上的性能有了显著提升,尤其是在OpenWebText数据集上,其困惑度得分接近教师模型。

研究者进一步在三个模型上应用了一系列自然语言任务,以评估它们在不同任务上的表现:

  1. 使用Hyena替代注意力机制的GPT模型。
  2. 使用传统注意力机制的Pythia 70M教师模型。
  3. 使用Hyena并通过联合知识转移(JKT)进行蒸馏的Pythia 70M学生模型。

他们使用了语言模型评估工具(lm eval)对这三个模型在多个不同的自然语言任务上进行了基准测试。测试结果如表2所示,所有结果都是在32位浮点精度下测量的,以确保可重复性并最小化由于低精度引起的机器误差。

表2中列出了不同任务的准确率(ACC)和标准偏差,包括ARC挑战、ARC简单、LOGIQA、PIQA、SCIQ、WINOGRANDE和WSC任务。从表中可以看出,使用Hyena的学生模型在某些任务上的表现略低于教师模型,但在Arc挑战和WSC任务上,学生模型的表现则略高于或显著高于其他两个模型。

表1的实验结果表明,在相同的GPU小时预算内,逐步知识转移与传统的预训练方法相比,在模型性能上具有优势。本方法在没有额外无监督学习的情况下取得了更好的性能,这表明了逐步知识转移策略的效率。

另外研究结果揭示了蒸馏作为无监督学习前的一个初始化步骤的潜力。这种方法在与传统预训练和纯知识转移相同的训练成本下提供了提高的性能。这表明知识蒸馏方法不仅提供了改进的初始性能,而且还允许在不增加额外训练费用的情况下进行额外的优化。

对结果的进一步检查强调了知识蒸馏对模型泛化的重大影响。的确,使用蒸馏在WikiText困惑度得分上的提高强调了本方法在增强模型用教师模型的知识对未见数据进行外推的能力方面的有效性。这为知识蒸馏在机器学习场景中的更广泛适用性和鲁棒性提供了宝贵的见解,特别是与传统的预训练策略相比。

表2表明,使用Hyena预训练的GPT模型通常具有与使用Hyena的Pythia 70M模型相似但略低的准确率。这些结果表明,使用Hyena的LLM通常能够像基于注意力的LLM模型一样表现良好,尽管基于Hyena的模型通常具有略低的测量性能。学生Pythia 70M JKT模型通常比预训练的基于注意力的Pythia 70M模型表现略差,尽管模型性能通常在相似的范围内,除了Sciq任务,学生模型的准确率明显低于GPT Hyena和教师模型。然而,在Arc挑战和Wsc任务中,Pythia 70M学生模型略微优于并显著优于其他两个模型。

结果表明,学生Hyena模型上的联合知识转移通常保留了其教师模型的语言能力,并且学生Hyena模型在某些情况下可以优于其教师模型。因为Hyena在直接比较时比注意力更有计算效率,并且因为联合知识转移可能比传统预训练更有计算效率,结果表明Hyena学生模型上的联合知识转移提供了一种计算效率高的替代传统基于注意力的LLMs预训练的方法。

论文链接:https://arxiv.org/abs/2401.17574

项目链接:

  • Pythia:本文中使用的模型实现之一。
  • The Pile:本文中用于训练的数据集之一。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值