(2024,Jamba1.5,ExpertsInt8量化,LLM,激活损失)大规模混合 Transformer-Mamba 模型

Jamba-1.5: Hybrid Transformer-Mamba Models at Scale

目录

0. 摘要

2. 模型架构

3. 服务考量与改进

3.1 ExpertsInt8 量化

3.2 激活损失

4. 吞吐量和延迟分析

5 训练

5.1 训练基础设施和数据

5.2 训练阶段

5.3 后训练

5.4 一些观察

6. 评估

6.1 学术基准

6.2 聊天机器人评估 

6.3 长上下文评估 

6.4 多语言能力

8. 结论


0. 摘要

我们推出了 Jamba-1.5,这是一种基于我们 Jamba 架构的新指令微调大型语言模型。Jamba 是一种混合 Transformer-Mamba 的 MoE 架构,能够在不同上下文长度中提供高吞吐量和低内存使用,同时保持与 Transformer 模型相同或更好的质量。我们发布了两种模型尺寸:Jamba-1.5-Large,具有 94B 有效参数,和 Jamba-1.5-Mini,具有 12B 有效参数。这两个模型都经过微调,以支持各种对话和指令跟随能力,并且具有 256K tokens 的有效上下文长度,是开放权重模型中最长的。为了支持成本效益的推理,我们引入了 ExpertsInt8,这是一种新颖的量化技术,能够在处理 256K tokens 上下文时,将 Jamba-1.5-Large 适配于一台配备 8 个 80GB GPU 的机器上运行,并且不会损失质量。在一系列学术和聊天机器人基准测试中,Jamba 模型表现优异,提供高吞吐量并在长上下文基准测试中优于其他开放权重模型。

2. 模型架构

(2024,Attention-Mamba,MoE 替换 MLP)Jamba:混合 Transformer-Mamba 语言模型 

(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模

Jamba-1.5-Large 基于我们开发的 Jamba [24] 混合解码器架构,该架构结合了 Transformer 层 [36] 与 Mamba 层 [13],一种状态空间模型 (SSM) [14, 15],以及专家混合 (MoE) 模块 [8, 34]。有关此架构的详细描述,请参见 [24]。

在 Jamba [24] 的工作中,我们发现 Transformer、Mamba 和 MoE 元素的组合有助于在吞吐量、内存使用和质量之间实现平衡。Jamba-1.5-Large 在更大规模上展示了这种灵活性。

Jamba-1.5-Large 遵循相同的 Jamba 结构,但具有更大的容量。它拥有 94B 有效参数,总参数量为 398B。该模型有 9 个模块,每个模块具有以下规格:

  • 每个模块有 l = 8 层。
  • a : m = 1 : 7 的注意力层与 Mamba 层的比例。在我们关于 Jamba 的研究中发现这个比例是最优的 [24],后续的工作也证实了类似的比例成功 [6, 37]。
  • 每 e = 2 层使用 MoE 替代单一的 MLP。共有 n = 16 个专家,每个 token 选择最优的 K = 2 个。
  • 隐藏状态的维度为 8192。
  • 注意力查询头的数量为 64,KV 头的数量为 8。

表 1 将 Jamba-1.5 模型与相似规模的公开模型进行了比较。Jamba-1.5-Mini 的有效参数数量与 Mixtral 8x7B 相当,而 Jamba-1.5-Large 的有效参数数量介于 LLaMA-3.1-70B 和 Mistral-Large-2 之间。同时,Jamba 模型在 KV 缓存内存使用量(在 256K tokens 上下文下)方面远小于所有其他模型,相比同类模型减少了大约一个数量级。

在这些设置下,并结合我们的专门量化技术(第 3.1 节),Jamba-1.5-Large 可以在一台配备 8 个80GB GPU 的机器上运行,支持长达 256K tokens 的上下文长度。

(2024|ICML,Mamba2,SSD,SSM,SMA,矩阵变换,张量收缩,张量并行)Transformer 是 SSM

对于此次发布,我们还尝试了 Mamba-2 [6],这是 Mamba 的一个更快且改进的版本,据报道它在性能上超越了单独使用 Mamba 和 Transformers 的模型。然而,正如图 1 所示,我们发现,在混合架构中,Mamba-1-Attention 组合的效果优于 Mamba-2-Attention,因此我们在 Jamba-1.5-Large 中使用了 Mamba-1。我们还发现混合架构的性能优于纯 Mamba-2。我们推测这可能是因为 Mamba-2 相比于 Mamba-1 的一些优势(特别是使用更大状态尺寸的能力)在 Mamba 层之间交错全注意力层时不那么显著,因为全注意力层可以从整个上下文中汇聚信息。 

3. 服务考量与改进

我们分享了一些见解和改进,旨在实现大规模高效服务 Jamba 模型。

3.1 ExpertsInt8 量化

为了支持 Jamba-1.5-Large 的高效服务,我们开发了一种新型量化技术,称为 ExpertsInt8。我们观察到,超过 85% 的模型权重位于 MoE 层,超过 90% 位于 MoE 或 MLP 层。我们希望在保持快速 BF16 内核的好处的同时,对这些权重进行量化。为此,我们将 MoE 和 MLP 权重量化为 INT8,存储为 INT8,并在实际计算之前将它们解量化回 BF16。重要的是,解量化步骤直接发生在 vLLM [18] 中的 fused_moe 内核内部。这样,解量化过程增加的开销微乎其微,甚至在延迟上比 BF16 更有优势。【我们将这归因于内核在相对较小的权重和激活块上操作,这些块在执行计算之前从 GPU HBM 移动到 SRAM。在我们的实现中,当权重量化为 int8 时,它们从 HBM 移动到 SRAM,因此由于内存占用减少了一半,所需时间也减少了。】

我们已经将修改后的 fused_moe 内核贡献给了 vLLM。【拉取请求在此处:https://github.com/vllm-project/vllm/pull/7415】 

我们的 ExpertsInt8 方法具有几个优势。

  • 首先,它非常快速;量化过程仅需在模型加载时几秒钟。
  • 其次,与 vLLM 中大多数其他技术不同,它不依赖于需要数小时或数天且可能不稳定的校准过程。
  • 第三,我们仍然可以使用 BF16 来处理大规模激活。
  • 第四,它可以在 A100 GPU 上使用,而 FP8 仅在 H100 上可用。
  • 最后,我们的量化在延迟上与 FP8 相当,同时超越了其他量化技术,并且不会导致质量损失。

图 2 比较了使用不同量化技术的延迟,包括 Jamba-1.5-Mini、Jamba-1.5-Large 和两个 Mixtral 模型(8x78B 和 8x22B)。在 H100 GPU 上,ExpertsInt8 的延迟与 FP8 相匹配。在 A100 上,由于 FP8 不可用,ExpertsInt8 是一种有吸引力的技术,显著超越了 GPTQ [9]。结合上述 ExpertsInt8 的优点,这使得它成为服务大型 MoE 模型的一个有吸引力的量化技术。

3.2 激活损失

在预训练过程中,我们发现某些激活值,特别是特定专家的输出以及最后的 Mamba 层的输出,在处理特定输入 token 时,逐渐增大,最终达到高达 4 × 10^6 的值。尽管我们发现这并未对使用 BF16 精度进行的预训练造成损害,但这些激活值的幅度可能会在推理过程中引发数值问题,因为一些量化库仅支持 FP16 精度,而 FP16 的最大范围为 64K。

为了解决这些问题,我们添加了一个“激活损失”(Activation Loss)项,其值与前向传播中激活值的均方值成正比,并设有可配置的 α 因子,以惩罚较大的激活值。通过实验,我们发现这种辅助损失对训练没有影响,即使 α 值达到至少 10^{−3}。对于 Jamba-1.5-Large,我们使用了 α = 10^{−5},这足以将激活值减少到一个可接受的范围(最大 2K-3K)。此外,添加这一辅助损失几乎瞬间降低了激活值,使得它仅在训练结束时添加也不会影响训练速度和质量。

为了验证这种方法,我们在模型上使用 FP16 激活值运行了完整的评估套件,结果与使用 BF16 的评估结果相同,没有出现 NaN/溢出。

4. 吞吐量和延迟分析

得益于混合 Jamba 架构,我们的 Jamba-1.5 模型提供了出色的吞吐量和延迟性能。图 3 和图 4 分别展示了 Jamba-1.5-Mini 和 Jamba-1.5-Large 的表现。如图所示,我们的模型在延迟和吞吐量方面均显著优于相同规模的模型。它们在处理长上下文时展现出显著优势,存在较大的性能差距。重要的是,Jamba-1.5-Large 在处理长上下文时依然高效,而大型的 LLaMA3-405B 不能在相同硬件上运行。【注:Large 比 Mini 有更高的时延和更低的吞吐量,直观理解就是速度换性能】 

5. 训练

5.1 训练基础设施和数据

Jamba-1.5-Large 在 NVIDIA H100 GPU 上训练,使用我们内部开发的专有框架,包括 FSDP、张量并行、序列并行和专家并行。对于专家并行,我们适配了 MegaBlocks [10]。

5.2 训练阶段

该模型的训练分为三个阶段。

  • 在预训练阶段,模型首先在我们内部的数据集上进行训练,该数据集最后更新于 2024 年 3 月。我们的预训练数据集是公开的网页文档、代码、书籍和科学文章的混合体。我们的预处理流程包括解析、质量过滤和去重。为了最大化利用公开数据,我们开发了自己的解析器,并使用它来提取文本和格式。数据混合的具体组成通过各种消融实验确定。该阶段包括多语言数据,重点关注以下语言:英语、西班牙语、法语、葡萄牙语、意大利语、荷兰语、德语、阿拉伯语和希伯来语。
  • 然后,模型进行了一个短的中期训练,重点训练长文档,以强调其长距离能力。
  • 最后,模型经过了后训练(post-training),如下一节所述。

5.3 后训练

我们后训练的方法旨在同时实现两个目标:(i)为模型提供各种技能和对话能力;(ii)保留预训练中的能力,特别是中期训练中的长上下文能力。这两个目标部分存在冲突,因为大多数现有的后训练数据集包含相对较短的示例。

考虑到这些因素,我们的后训练过程包括在高质量对话数据、特定技能数据和长上下文数据上进行监督微调 [32, 39]。混合这些不同类型的数据旨在保留长上下文能力并获得所需的技能。如下面的评估所示,我们发现我们的模型在长上下文评估中表现非常好。

在进行监督微调时,我们大量使用合成数据,这在最近的基础模型中很常见 [7],并反映了我们构建复合 AI 系统 [20] 的结构化数据的方法。我们开发了多种不同的数据合成流程,针对不同的模型能力。所有流程都应用以下模式:(i)在目标分布中采样或生成提示;(ii)从语言模型中生成响应;(iii)根据自动验证和评分过滤或排名响应;(iv)后编辑以去除伪影并适应所需格式。我们使用不同的模型、提示、采样、过滤和编辑来处理不同的数据管道,从而组成最终的数据混合体。

我们基于大量主要内部的自动化指标选择了最终的训练配方(数据混合和超参数)。两个 Jamba-1.5 模型使用相同的控制 token 和格式模板进行微调,我们将这些作为我们发布的一部分,提供 HF 兼容的标记器和聊天模板;有关详细信息,请参见模型卡。

以下是一些合成数据生成的显著示例:

  • 基于表格的 QA:我们生成表格数据及其相应的问题-答案对,如我们在表格理解 [20] 中所示。然后,我们使用语言模型将表格转换为自然语言段落。我们生成的训练示例包括提取、聚合和归属任务,涉及给定表格中特定行或列的文本。

  • 文档 QA:给定一个文档,我们提示语言模型生成问题-答案对,适用于单个或多个段落。我们有时通过添加类似的文本将这些示例嵌入更长的上下文中,以鼓励长上下文理解和归属。

  • 工具使用:我们使用开源的 Glaive 函数调用数据集 [1] 作为起点,通过各种启发式方法和对输出模式的验证进行过滤。为了支持并行函数调用,我们首先为 Glaive 中的每个函数生成多个有效的参数分配。接下来,我们从这些有效的参数分配中采样子集,针对相同函数和不同函数,生成与函数调用集对应的用户请求。最后,我们提示函数调用语言模型对这些生成的用户请求做出响应,并仅保留与原始参数分配匹配的函数调用的响应。

  • 可控性:我们定义了一组可以轻松验证的指令,并合成了包括一个或多个约束的通用文档草拟任务的提示。我们从语言模型中生成这些提示的完成,并基于对我们细粒度指令的验证以及通用奖励模型进行拒绝采样。为了支持系统消息中的指令,我们选择了多种这种类型的提示,这些提示共享细粒度指令实例,并将这些提示重新格式化为多轮对话,将指令移到系统消息中。

5.4 一些观察

我们分享了一些关于 Jamba-1.5 开发过程中的观察。这些观察虽然尚未完全探索,但希望能够激发社区进一步研究这些问题。首先,虽然我们仅在后训练阶段包含了非常小比例的非英语数据,仅针对几个语言和特定技能,我们的 Jamba-1.5 模型在多语言环境下表现相当出色。我们确实在预训练阶段包括了多语言数据,如上所述。因此,我们推测模型能够在主要以英语进行后训练时,利用预训练阶段学到的知识。

其次,我们高效的 Jamba 架构降低了在长上下文上的微调成本,使我们能够在给定预算下进行更多实验。因此,我们能够在后训练阶段实验多种不同的训练配方。

最后,尽管偏好微调算法如 PPO [33] 或 DPO [29] 可以改善模型输出与人类意图之间的对齐,我们发现细致的合成数据生成、数据过滤和监督微调的结合对获得强大的后训练模型至关重要。

6. 评估

虽然我们相信基准测试仅与实际应用的成功和用户满意度部分相关,但我们仍报告了一些关键公共基准的结果。首先,我们报告标准学术基准的结果。然后,我们评估模型在聊天机器人基准上的表现。最后,我们在多个长上下文评估和多语言评估中评估 Jamba-1.5-Large。

我们将其与最近的相同规模范围的开放权重模型进行比较:在比较 Jamba-1.5-Large 时,与 LLaMA-3.1 70B 和 Mistral-Large-2-123B;在比较 Jamba-1.5-Mini 时,与 LLaMA-3.1-8B 和 Gemma-2-9B 进行比较。

6.1 学术基准

6.2 聊天机器人评估 

在本节中,我们对 Jamba-1.5 模型在两个聊天机器人(chatbot)场景中的表现进行了吞吐量评估:

  • Arena-Hard [22],这是一个包含 500 个具有挑战性的用户查询的数据集,使用 GPT4-Turbo 作为评判;
  • WildBench [25],该数据集也使用 GPT4-Turbo 作为评判,但进行了长度偏差的缓解处理。 

6.3 长上下文评估 

我们在 RULER 基准上进行评估,RULER 是一组 13 个合成任务,旨在评估语言模型的长上下文能力。RULER 包括 8 种变体的针在大 haystack 检索任务 [17, 21, 27, 28],其中包含多个“针” [2]。此外,它还有一个变量跟踪任务,需要返回一系列变量绑定,两个聚合任务,需要返回最常见的词汇,以及两个问答任务,其中包含来自自然数据集 [30, 41] 的答案的段落被插入到随机段落中,以模拟长上下文。

接下来,我们在 ∞BENCH 数据集上进行评估,该数据集旨在评估语言模型的长上下文能力,平均长度为 100K 词汇。我们专注于两个英文任务,理解长篇小说:问答(EN.QA)和多项选择题问答(EN.MC)。如表 5 所示,Jamba-1.5 模型在这方面表现非常出色,超越了同样规模的 LLaMA-3.1 和 Mistral-Large-2 模型。(由于 Gemma-2 9B 的上下文窗口较短(8K),我们未报告其结果。)

6.4 多语言能力

我们对 Jamba-1.5 在非英语语言中的能力进行了基本评估。具体来说,我们报告了在多语言 MMLU 数据集 [19] 上的结果,该数据集通过 LM Evaluation Harness [11] 分发。如表 6 所示,Jamba-1.5-Mini 的表现与对比模型相当或更好。Jamba-1.5-Large 略微落后于其可比模型,但仍展现出良好的多语言能力。

8. 结论

我们介绍了 Jamba-1.5-Large 和 Jamba-1.5-Mini,这两个基于 Jamba 混合 Transformer-Mamba 架构的大规模模型。两个模型在学术基准、聊天机器人评估和长上下文评估中均表现出色,同时提供了改进的延迟和吞吐量,特别是在处理长上下文时。我们发布了模型权重,希望社区能够使用这些模型并在此技术基础上进行进一步开发。

 

论文地址:https://arxiv.org/abs/2408.12570

Jamba 开源模型许可:https://www.ai21.com/licenses/jamba-open-model-license。

项目页面:https://huggingface.co/ai21labs

Jamba-1.5-Mini:https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini

Jamba-1.5-Large:https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large

公和众与号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
加 VX 群请备注学校 / 单位 + 研究方向
CV 进计算机视觉群
KAN 进 KAN 群

  • 13
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值