文章链接:HET: Scaling out Huge Embedding Model Training via Cache-enabled Distributed Framework
登载刊物:Proceedings of the VLDB Endowment
源码链接: https://github.com/PKU-DAIR/Hetu
1.背景介绍
现有的分布式训练框架面临着嵌入模型的可扩展性问题,因为从服务器更新和检索共享的嵌入参数通常占训练周期的主导地位。
为了在高维数据(比如文本中的单词)上训练模型,通常需要使用嵌入式模型(Embedding Model,MD),MD可以将稀疏的高维特征空间投影到连续的低维嵌入空间中。比如NLP中会使用嵌入层将文本投射成一个个连续的ID号,再进行训练。MD可以提取一些简单的信息,并处理特定的下游任务,比如推荐、语句定调等
由于每个特征需要由一组嵌入向量(Embedding Vectors,EV)来表示,而MD的规模与EV大小存在正相关,当稀疏的高维数据很庞大时,EV也会表示的非常长,相对应的MD规模也会增加。这就导致许多MD难以复制到一个机器上运行,甚至难以存储在一个主机的内存中。例如,Google 的真实世界文档嵌入模型的参数占用几个TB,百度中的商业点击率预测模型具有1011个输入稀疏特征,也需要10TB参数。
现代分布式ML系统通常采用参数服务器(Parameter Server,PS)框架来扩展模型。服务器通常通过聚合来自工作器的更新并更新全局参数来维护全局共享参数。工作进程只与服务器节点通信,更新和检索共享参数。 现有的ML系统通常支持数据并行,其中工作者通常包含机器学习模型的副本,并被分配给整个训练数据的相等大小的分区。在分布式训练期间,通常采用批量同步并行(BSP)或异步并行(ASP)来更新模型参数。
BSP更新参数
x ( t + 1 ) = x ( t ) − η [ 1 N ∑ N i = 1 G i ( x ( t ) ; ξ i ) ] x(t+1) = x(t)-\eta \left [ \frac{1}{N} \sum_{N}^{i=1} G^{i}(x(t);\xi ^{i})\right ] x(t+1)=x(t)−η[N1N∑i=1Gi(x(t);ξi)]
- η \eta η: learning rate
- ξ i \xi ^{i} ξi: randomly sampled from thetraining set
- G i ( ⋅ ) G^{i}(·) Gi(⋅):gradient from the 𝑖-th worker
ASP更新参数
x ( t + 1 ) = x ( t ) − η ⋅ G i ( x ( t ) ; ξ i ) x(t+1) = x(t)-\eta· G^{i}(x(t);\xi ^{i}) x(t+1)=x(t)−η⋅Gi(x(t);ξi)
异步并行训练中,每个worker可以独立地计算和更新梯度,而不需要等待其他worker的结果。则可以避免同步和掉队导致的通信开销。但由于梯度的陈旧性,模型会退化。
这种设置面临着大型嵌入模型的可扩展性问题,最大的效率低下来自更新和检索共享的特征嵌入参数。例如在动态网页(ASP)中使用TensorFlow,即跑线上模型比如ChatGPT,高达86%的训练时间花在嵌入获取和更新上,其原因在于EM使用的深度神经网络计算复杂度较低,而嵌入数据则十分庞大。
随着新兴的强大GPU之间的差距越来越大,以及网络带宽的缓慢增长,嵌入式通信瓶颈将变得更加严重。这都将导致分布式的框架的扩展变得更加困难。
2.内容摘要
现有的分布式训练框架面临嵌入模型的可扩展性问题。HET集成了一个嵌入缓存,该缓存利用了嵌入的倾斜流行分布,并利用它来解决通信瓶颈。作者还引入了一个新的一致性模型,该模型在每个嵌入的基础上提供了细粒度的一致性保证。HET实现了高达88%的嵌入通信减少,并在最先进的基线上实现了高达20.68倍的性能加速。
挑战和机遇
A 通信成本
大规模EM同时受到密集参数和稀疏嵌入的通信瓶颈,而后者占主导地位。
可以看到除个别模型外,大多数模型的通讯时间占训练时间的70%以上,并且系数嵌入规模大的模型,通讯时间的占比也高。
B 偏斜现象
上图展示了一些流行工作负载上的嵌入更新频率的偏斜分布,包括点进率预测(Criteo)、引文网络(OGBN-MAG)以及产品购物网站(Amazon)。Criteo中头部10%的流行嵌入可以占总更新数量的90%。这种偏斜分布非常普遍,例如推荐模型,LDA主题模型和图学习模型中都有案例。
C 鲁棒性
引入缓存会带来在存在写入时确保一致性的问题。现有的嵌入模型属于迭代收敛算法的范畴,该算法从随机选择的初始点开始,通过迭代地重复一组过程来收敛到最优值。这种迭代收敛过程已被证明对有界量的不一致性具有鲁棒性,并且仍然正确收敛。此属性允许框架通过放宽缓存一致性模型和从本地(过期)缓存阅读来提高系统性能。
模型解读
HET将训练数据分发到多个worker。每个worker持有一个密集模型参数的复制,并在训练期间使用AllReduce进行梯度同步。HET将共享的嵌入参数组织为表格: 整个嵌入参数存储在HET服务器上的全局嵌入表中。客户端负责管理本地缓存,通过与服务器通信来控制本地和全局嵌入表的不一致性。
每个嵌入表示由唯一键表示的稀疏特征ID:k,这些嵌入被按行集合在嵌入表中。全局嵌入表中的
x
k
x_{k}
xk记录了一个全局Lamport时钟;
c
g
c_{g}
cg表示该嵌入的更新总数。每个worker可以在本地缓存嵌入表的一小部分,称为缓存嵌入表。缓存嵌入表的
x
k
i
x^{i}_{k}
xki有两个时钟:一个是
c
s
c_{s}
cs表示上次从服务器获取嵌入
x
k
x_{k}
xk到
w
o
r
k
e
r
i
worker_{i}
workeri时观察到的全局时钟;还有一个
c
c
c_{c}
cc记录本地当前更新的本地时钟时间。
从逐嵌入的角度来看,在训练过程中,每一个嵌入可能存在于多个缓存嵌入表中。为了缓存一致性保证,作者在读写k行内容时设计了一个Cache.CheckVaild(k) 方法。该方法对本地 x k i . c c x^{i}_{k}.c_{c} xki.cc提出了两个限制:
- 当前时钟不应太早于开始时钟: x k i . c c ≤ x k i . c s + s x^{i}_{k}.c_{c}≤x^{i}_{k}.c_{s}+s xki.cc≤xki.cs+s
- 当前时钟不应与全局时钟相差太大: x k . c g ≤ x k i . c s + s x_{k}.c_{g}≤x^{i}_{k}.c_{s}+s xk.cg≤xki.cs+s
篇幅有限,这里就不进行后续展开讲解了,经总结证明,在HET框架下,算法是保证收敛的。
关键技术
1.热缓存
受嵌入模型的关键特征的启发:受欢迎程度偏度和耐老化性。具体来说,嵌入的流行度分布通常是高度偏斜的,通常遵循幂律分布。
这个特性意味着提高性能机会:在每个工作者处的热嵌入(比如高频词)的小缓存可以有效地保存网络带宽,同时随着工作者的数量缩放训练吞吐量。
2. 一致性设计
在多个高速缓存中复制共享的嵌入数据在存在写入的情况下引起一致性问题。幸运的是,EM是迭代收敛算法,在训练过程中的一些陈旧错误是可以接受的,不会阻止收敛。换句话说,EM对于有界量的不一致性是鲁棒的(比如过期的共享状态,只会延缓收敛速度)。通过适当地放松一致性保证,就可以利用缓存来获得显着的系统改进。
文章提出了嵌入时钟界一致性的概念,借此工作者可以看到比某些嵌入特定时钟更早的嵌入的所有更新。
实验结果
收敛速度对比:在HET框架,缓存表S越大,收敛速率越快,因为使用更大的S可以减轻更多的通信成本。
通讯时间加速展示:由于细粒度缓存和一致性,HET实现了显著的性能改进,并减少了高达88%的嵌入通信。在不同大小带宽的以太网集群中,训练的通讯时间都被极大的加速了。
收敛质量:左表展示了不同老化阈值下的收敛模型性能(即测试AUC),结果显示尽管在高老化程度下模型退化变得明显,但在中等老化程度(S = 100)下也能达到目标模型质量。
右表展示了缓存未命中的概率。由于缓存命中(未命中)意味着预测使用陈旧的(最新的)嵌入参数,所以使用缓存未命中率来度量使用陈旧的(同步频率较低的)嵌入参数的预测频率。来自两个模型(= 0和= 100)的预测分布非常接近,这表明旧的嵌入不会引起显著的预测偏差。
3.文章总结
在高维数据上训练的EM在现代网络公司中很常见,并对标准框架提出了额外的挑战:高通信开销导致嵌入工作负载具有低的执行效率和可伸缩性。为了解决这个性能瓶颈,我们提出了HET,一个系统框架,利用嵌入式缓存架构结合细粒度的一致性和陈旧的写协议。实验结果表明,与最先进的基线相比,HET可以减少高达88%的嵌入通信,并实现高达20.68倍的性能改进。