LoRA结合上MoE的路由功能,让LoRA的参数增量不再是静态的死板一块,而是可以根据不同任务的输入来动态生成,有效解决了SFT大量训练数据和世界知识遗忘的冲突。最近组里同学在尝试实现LoRAMoE,意在解决大模型微调后遗忘世界知识的问题。参考的是复旦23年年底的这篇论文:"LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment"[1],将LoRA和MoE方法做了组合,有效解决了上述问题,本文详细讲解一下这篇论文。

1. 背景

大模型经过大量语料的无监督预训练后,得到所谓的基座模型,这时候通常还不能很好地完成下游任务,需要经过有监督的微调(SFT)后才能和人类指令对齐,释放其全部潜力。

一般来说,SFT的训练数据不需要太多,但当下游任务增多或者需要强化特定任务的性能时,增加SFT训练数据还是有必要的。如下图的左侧部分,当SFT数据从100K提升到3M时,大部分任务的性能显著增强。

但随着SFT数据的大规模增加,新的问题出现了:如下图的右侧部分所示,在某些评测数据集上性能显著下降,与之相伴的是大模型的参数变化量剧增(见红色线段)。这些数据集属于闭卷问答任务(Closed-Book Question Answering,简称CBQA),即只给大模型输入问题,大模型主要依靠在预训练过程中习得的世界知识来给出答案。

这里补充一下:像TriviaQA、Natural Questions这类数据其实是包含问题相关上下文的,也就是说如果用作开卷问答任务,则输入不仅包括问题还包括上下文,大模型可以从上下文中总结出答案;但如果用作闭卷问答任务,则输入中不提供上下文。论文[2]中6.1节有提到。

LoRAMoE_数据

我们有理由怀疑CBQA的性能下降与大模型世界知识的崩溃有关,下面将通过实验证明这一点。首先验证CBQA的推理依赖大模型的世界知识,其次证明CBQA数据集上性能的大幅下降归因于大规模微调会显着改变模型参数,导致世界知识的破坏,即发生知识遗忘。

2. 大规模微调导致世界知识破坏

下面详细介绍一下实验过程,即做实验证明大规模SFT导致大模型的世界知识严重受损,引起知识遗忘。

2.1 实验设计

数据集

准备了7种任务的数据集,分别是CBQA(闭卷问答)、coreference resolution(指代消解)、NLI(自然语言推理)、summarization(文本摘要)、multi-lingual translation(多语言翻译)、reading comprehension(阅读理解)、text classification(文本分类)。具体数据集见下图:

LoRAMoE_权重_02

基座模型

采用LLaMA-2-7B作为基座模型,属于在学术界非常流行的LLM之一。

评估

将任务分为两类:CBQA数据集用于评估模型的世界知识,前人工作发现CBQA数据集中有train-test重叠,因此做了过滤,只用未重叠的test集,命名为Filtered TriviaQA和Filtered NQ这种;其他的下游任务用opencompass[3]框架来评测。

2.2 实验结果

用前面说的7种任务的混合数据集微调大模型,数据规模逐渐增加,然后看不同下游任务的性能表现。如下图所示,像左侧的摘要、NLI、机器翻译这类任务,随着SFT训练数据的增加,性能显著提升;但是右侧的CBQA任务,却出现断崖式下跌。

LoRAMoE_权重_03

我们已经高度怀疑CBQA的性能下降是由于大模型的世界知识崩坏引起的,为了更加确信,接下来我们仔细实验一下CBQA和大模型的世界知识到底有什么关系,具体做法是单独拿CBQA的25万条样本训练大模型,然后看大模型在未重叠的测试集上的表现。

如下图所示,在训练一开始大约1000样本的时候,性能已经快速提升到了很高的点,后续再增加更多的训练样本其实提升很有限。说明少量样本微调就帮助大模型完成了人类指令的对齐,大模型完成CBQA指标评测的能力主要依靠的是内在的世界知识,而不是微调过程中训练样本灌输的。因此我们更加确性CBQA指标高度依赖大模型在预训练过程中学到的世界知识,上图中CBQA的性能下降的原因就是世界知识的破坏。

LoRAMoE_数据_04

再进一步实验,证明是大规模的微调导致了世界知识受损。具体做法如下表:第三列仅用CBQA训练数据微调,是可以CBQA测试集上打败Baseline的;而第四列是分两阶段,先用300万不包换CBQA的数据微调,然后再用和第三列同样的CBQA数据继续微调,结果在CBQA测试集上的表现比Baseline都差很远。

对比一下第三列和第四列,差别只在后者多了一个第一阶段300万数据的微调,说明它正是大模型世界知识崩塌的罪魁祸首,第二阶段即便加上CBQA训练数据,也无法弥补回来。同时发现大模型的参数发生了巨大变化,正好和前面结论相互佐证。

LoRAMoE_数据_05

3. LoRAMoE方法

前面的实验表明,有些下游任务需要SFT的训练数据越多越好,即LLM的参数改变越大越好,而有些下游任务需要尽可能保留世界知识,即参数变化越小越好。这种冲突对于一般的全参微调或者LoRA微调都是搞不定的。论文[1]引入了MoE的思想来解决,实现LoRA微调的自适应。题外话,MoE在搜广推领域早就烂大街,这里又用到了LLM微调领域,说明技术都是相通的。

下面先分别介绍一下MoE和LoRA,然后看如何结合。

3.1 MoE简介

MoE全称Mixture of Experts,意味着有多个专家网络投票共同输出结果,只不过每个专家根据输入不同,分配不同的权重。我们可以想象成不同的专家具备不同领域的能力,然后根据输入的特征,给更匹配的专家分配更高的权重,从而动态组合专家输出。

LoRAMoE_数据集_06

3.2 LoRA简介

LoRAMoE_数据集_07

论文[4]中对缩放系数的作用只是很简略的说了一段,令人摸不着头脑。对此我自己做了推导和调研,发现原生的缩放因子并未最佳选择,因和本文主旨无关,后续单独作文述之。

3.3 LoRA+MoE

我们结合LoRA和MoE的优点,将二者组合起来,便是LoRAMoE,看下图所示:

LoRAMoE_人工智能_08

LoRAMoE_数据_09

对照公式我们再回看上面的图4,只能说图中的LoRAMoE部分是个示意图,意会即可。

3.4 专家平衡约束

如果不加任何约束微调MoE,经常会出现门控函数收敛到一种状态,即少数专家掌握了话语权,其他专家权重非常小,失去了平衡。

LoRAMoE人为地将专家分为两组,一组专注于学习下游任务,另一组专注于将世界知识和人类指令对齐。

LoRAMoE_数据_10

3.5 实验

实验参数:

LoRAMoE层只替换LLM中FFN的线性层,且每个LoRAMoE层的专家数为6,其中3个用于下游任务,另外3个用于对齐世界知识。 

LoRAMoE_人工智能_11

实验结果:

LoRAMoE_人工智能_12

可以看到,相比于全量微调或者传统的LoRA,本文的方法都取得明显提升,世界知识的遗忘问题也不再发生。详细结论不再细表,总之,LoRA结合上MoE的路由功能,让LoRA的参数增量不再是静态的死板一块,而是可以根据不同任务的输入来动态生成,有效解决了SFT大量训练数据和世界知识遗忘的冲突。