深度解读 Llama 3.1 技术报告:从405B参数到24000块H100集群

Meta在最近发布了他们的开源大模型Llama 3.1,引起了广泛的关注和讨论。随着热度逐渐退潮,Llama 3.1 的详细技术报告也公开了。这份报告揭示了很多关于Llama 3.1 的技术细节和实现原理。本文将从模型参数、基础设施、预训练、后训练等方面,详细剖析Llama 3系列模型的关键技术。

一、模型参数

Llama 3.1 最引人注目的地方就是其高达405B的庞大参数。虽然现今已经不再是纯粹的规模为王的时代,但Meta通过Scaling Law(缩放定律)来确定旗舰模型的最佳大小,以保证在特定基准的性能和经济效益之间找到平衡。

1.1 Scaling Law的挑战

Scaling Law的应用存在两个主要挑战:

  1. 现有的Scaling Law通常只能预测下一个词的损失,不能预测特定基准的性能。
  2. Scaling Law可能会因为小计算量的预训练变得嘈杂和不可靠。

1.2 两阶段方法

Meta采用了两阶段方法来解决上述挑战:

  1. 确定最优模型在下游任务上的负对数似然(NLL)和训练FLOPs之间的相关性。
  2. 利用Scaling Law模型和更高计算FLOPs训练的旧模型,将NLL与基准任务的准确率相关联。

具体步骤如下:

  1. 保证在6×10^18 FLOPs到10^22 FLOPs之间的计算预算。
  2. 通过预训练模型构建Scaling Law,使用余弦学习率计划预热2000个训练步骤,峰值学习率设定在2×10-4到4×10-4范围内,并将余弦衰减设置为峰值的0.1。

二、基础设施

为了支持Llama 3.1 的训练,Meta整合了24000多块H100 GPU,重新搭建了生产集群。这个集群配备了80GB的HBM3和Meta的Grand Teton AI服务器平台,每个服务器配备八个GPU和两个CPU。

2.1 硬件架构

在服务器内部,八个GPU通过NVLink连接,模型的训练使用了Arista 7800交换机和Minipack2 OCP交换机,采用RoCE网络拓扑结构。通过三层CLOS网络连接,形成了一个拥有3072个GPU的pod,最终形成了24000个GPU的集群。

2.2 存储网络

为了支持庞大的模型训练,Meta构建了一个分布式文件系统,提供高达240PB的存储空间,并支持每秒2TB的持续吞吐量和每秒7TB的峰值吞吐量。通过这一系列设置,Meta希望将检查点期间的GPU暂停时间最小化,同时增加检查点的频率,从而减少恢复后丢失的工作量。

2.3 负载均衡和拥塞控制

由于大语言模型的训练会产生大量的网络流量,Meta采用了增强等价多路径路由(E-ECMP)协议,对RoCE数据包头中的额外字段进行哈希,从而在不同的网络路径上平衡网络流。此外,Meta还采用了深缓冲交换机,在骨干网络上进行部署,以降低集合通信模式引发的瞬间拥塞和缓冲问题。

三、预训练

Llama 3.1 的预训练数据包含截至2023年末的各种数据源,经过多次去重和数据清洗,以保证获得高质量的Token,同时删除大量个人身份信息和成人内容。

3.1 数据混合和退火

Meta开发了一个分类器对网络数据进行分类,并通过Scaling Law实验确定最佳的数据混合比例。最终的混合数据集中包含大约50%的一般知识token,25%的数学和推理token,17%的代码token,以及8%的多语言token。此外,Meta还通过退火技术对高质量数据进行上采样,以提高预训练模型的性能。

3.2 训练方法

405B模型的预训练采用了余弦学习率计划,峰值学习率为8×10-5,线性预热8000步,然后在1200000个训练步骤内衰减到8×10-7。Meta发现,在训练初期使用较小的批量大小可以提高训练的稳定性,随后再增加批量大小,从而提高效率。

四、后训练

后训练阶段主要包括奖励和微调模型。Meta使用人类标注偏好数据训练的奖励模型,以及监督式微调(SFT)和直接偏好优化(DPO)。

4.1 奖励模型和SFT

基于最后405B的检查点,Meta训练了一个涵盖不同能力的奖励模型,并使用偏好数据进行了奖励建模,标注被划分成四个偏好等级。随后,研究人员使用奖励模型对人工标注的提示进行了拒绝采样,并将拒绝采样数据和其他数据源合并,再使用标准的交叉熵损失对预训练语言模型进行了监督式微调。

4.2 DPO

在SFT之后,Meta进一步使用DPO对SFT模型进行训练,以便与人类的偏好对齐。相比于PPO,DPO针对大参数模型的计算量更少,并且性能更好。此外,为了提高DPO训练的稳定性,研究人员对DPO进行了多项修改,包括屏蔽特殊格式的Token,并添加了一个额外的负对数似然损失项。

五、推理

Llama 3.1 的405B模型在FP16推理时至少需要810GB的显存,至少要两台装有8个H100的服务器。如果服务器之间有NVLink和NVSwitch高速互联,可以使用张量并行,而在带宽较低或延迟较长的情况下,则需要使用流水线并行,并使用微批处理来提高吞吐量。

FP8推理则只需要一台服务器即可部署,不仅能够让预填充阶段的吞吐量提高50%,而且在解码阶段也能获得更好的吞吐量-延迟权衡。

结论和未来展望

Llama 3.1 技术报告详细揭示了Meta在大模型训练上的诸多技术细节和挑战。通过分析这些细节,可以看出Meta在硬件架构、数据处理、训练方法等方面都进行了大量的创新和优化。未来,随着模型和集群规模的进一步扩大,Meta还将面临更多的挑战,但可以预见的是,这些技术的进步将为大语言模型的发展带来更多可能性。
在这里插入图片描述

  • 13
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值