ADDITION IS ALL YOU NEED FOR ENERGY-EFFICIENT LANGUAGE MODELS 论文理解——L-Mul

原文地址:https://arxiv.org/html/2410.00907
简单来说该论文提出了一种方法用于替代大型神经网络中浮点张量的乘法,该方法具有更高的精确度和更低的计算量。
下文给出论文各部分的粗略理解


指数部分乘法优化

一个nbits, e位指数, m位底数的浮点运算为

Mul ( x , y ) = ( 1 + x m ) ⋅ 2 x e ⋅ ( 1 + y m ) ⋅ 2 y e = ( 1 + x m + y m + x m ⋅ y m ) ⋅ 2 x e + y e \text{Mul}(x, y) = (1 + x^m) \cdot 2^{x^e} \cdot (1 + y^m) \cdot 2^{y^e}\\= (1 + x^m + y^m + x^m \cdot y^m) \cdot 2^{x^e + y^e} Mul(x,y)=(1+xm)2xe(1+ym)2ye=(1+xm+ym+xmym)2xe+ye

在该运算中,瓶颈复杂度来源于 x m ⋅ y m x^m \cdot y^m xmym 复杂度为O(m^2)
论文将该操作替换为一个新的乘法算法:
L-Mul ( x , y ) = ( 1 + x m + y m + 2 − l ( m ) ) ⋅ 2 x e + y e \text{L-Mul}(x, y) = \left(1 + x^m + y^m + 2^{-l(m)}\right) \cdot 2^{x^e + y^e} L-Mul(x,y)=(1+xm+ym+2l(m))2xe+ye
l ( m ) = { m if  m ≤ 3 , 3 if  m = 4 , 4 if  m > 4. l(m) =\begin{cases}m& \text{if } m \leq 3,\\3 & \text{if } m = 4,\\4& \text{if } m > 4.\end{cases} l(m)= m34if m3,if m=4,if m>4.

该算法复杂度为O(m)
此外,在传统浮点乘法器中,如果尾数和大于2,则需要手动舍位到1.x
而使用该方法时,除了尾数乘法,还跳过了手动舍位
请添加图片描述


基于L-Mul的注意力机制

K = H ⋅ W k , Q = H ⋅ W q , V = H ⋅ W V A = s o f t m a x [ L − m a t m u l ( Q , K T ) d ] , H ′ = L − m a t m u l ( A , H ) K=H\cdot W_{k},Q= H \cdot W_{q}, V=H \cdot W_{V}\\A=softmax \left[ \frac{\mathcal{L}-matmul(Q,K^{T})}{\sqrt{d}}\right],H^{'}=\mathcal{L}-matmul(A,H) K=HWk,Q=HWq,V=HWVA=softmax[d Lmatmul(Q,KT)],H=Lmatmul(A,H)


给出准确度和消耗分析

L − m u l \mathcal{L}-mul Lmul f p 8 _ e 4 m 3 fp8\_e4m3 fp8_e4m3更精确,比 f p _ e 5 m 2 fp\_e5m2 fp_e5m2计算量更低

考虑 x = ( 1 + x m ) ⋅ 2 x e x=(1 + x_m) \cdot 2^{x^e} x=(1+xm)2xe y = ( 1 + y m ) ⋅ 2 y e y=(1 + y_m) \cdot 2^{y^e} y=(1+ym)2ye
如果拆解为
x = ( 1 + x k + x r ) ⋅ 2 x e , x ′ = ( 1 + x k ) ⋅ 2 x e x=(1 + x_k+x_r) \cdot 2^{x^e},x^{'}=(1 + x_k) \cdot 2^{x^e} x=(1+xk+xr)2xe,x=(1+xk)2xe
y = ( 1 + y k + y r ) ⋅ 2 y e , y ′ = ( 1 + y k ) ⋅ 2 y e y=(1 + y_k+y_r) \cdot 2^{y^e},y^{'}=(1 + y_k) \cdot 2^{y^e} y=(1+yk+yr)2ye,y=(1+yk)2ye
其中 x k , y k x_k,y_k xk,yk x m , y m x_m,y_m xm,ym的前 k k k位, x r , y r x_r,y_r xr,yr k k k位取整后被忽略的值。

考虑 x x x y y y具有 m m m位尾数的全部精度,比如Float16的尾数为10-bit,BFloat16尾数为7. 那么 M u l ( x , y ) Mul(x,y) Mul(x,y)的误差和期望为
e m u l k = M u l ( x , y ) − M u l ( x ′ , y ′ ) = ( x k y r + y k x r + x r + y r + x r y r ) ⋅ 2 x e + y e e^{k}_{mul}= Mul(x,y)-Mul(x^{'},y^{'})=(x_{k}y_{r}+y_{k}x_{r}+x_{r}+y_{r}+x_{r}y_{r})\cdot2^{x_{e}+y_{e}} emulk=Mul(x,y)Mul(x,y)=(xkyr+ykxr+xr+yr+xryr)2xe+ye
E [ e m u l k ] = f 1 ( m , k ) ⋅ E [ 2 x e + y e ] E[e^{k}_{mul}]=f_{1}(m,k)\cdot E[2^{x_{e}+y_{e}}] E[emulk]=f1(m,k)E[2xe+ye]
则期望的误差为
E [ e l m u l k ] − E [ e m u l k ] = f 2 ( k ) ⋅ E [ 2 e x + e y ] , E [ e l m u l k ] = [ f 1 ( m , k ) + f 2 ( k ) ] ⋅ E [ 2 e x + e y ] E[e^{k}_{lmul}]-E[e^{k}_{mul}]=f_2(k)\cdot E[2^{e_x+e_y}], E[e^{k}_{lmul}]=[f_1(m,k)+f_2(k)]\cdot E[2^{e_x+e_y}] E[elmulk]E[emulk]=f2(k)E[2ex+ey],E[elmulk]=[f1(m,k)+f2(k)]E[2ex+ey]

x m , y m x_m,y_m xm,ym均匀分布时
E [ x k ] = 1 2 ( 1 − 2 − k ) , E [ x r ] = 1 2 ( 2 − k − 2 − m ) E[x_k]=\frac{1}{2} (1-2^{-k}), E[x_r]=\frac{1}{2}(2^{-k}-2^{-m}) E[xk]=21(12k),E[xr]=21(2k2m)

通过估计 f 1 ( m , k ) , f 2 ( k ) f_1(m,k),f_2(k) f1(m,k),f2(k)并推断 E [ e k l m u l ] , E [ e k m u l ] E[e^{k}{lmul}], E[e^{k}{mul}] E[eklmul],E[ekmul],论文发现 L − M u l \mathcal{L}-Mul LMul在均匀分布下精度超过 f p _ e 5 m 2 fp\_e5m2 fp_e5m2。然而在预训练的LLMs中实际权重分布往往存在偏差。基于五种常用LLMs的组合分布,论文发现 L − M u l \mathcal{L}-Mul LMul在5-bit尾数操作数下精度超过 f p _ e 4 m 3 fp\_e4m3 fp_e4m3。 并给出了误差估计表:
在这里插入图片描述


门复杂度估计:

对于两个fpn_eimj数相乘需要:

  • sign prediction,符号预测
  • exponent addition with offset, 整数部分加法和进位
  • a j+1-bitmantissa multiplication, 尾数乘法
  • exponent rounding. 整数部分取整

尾数乘法需要(j+1)^{2} ADD操作,3个半加器,和2j-2全加器。
取整需要i个半加器,
一个全加器需要2AND,2XOR,1OR,每个XOR需要4NAND,也就是最后需要11个门器件。而一个半加器需要5个门器件。
最后的统计结果约为:
N f p 16 × ≈ 584 , N f p 8 − e 4 m 3 × ≈ 325 , N f p 8 − e 5 m 2 × ≈ 296 N^{\times}_{fp16}\approx 584, N^{\times}_{fp8-e4m3}\approx325, N^{\times}_{fp8-e5m2}\approx296 Nfp16×584,Nfp8e4m3×325,Nfp8e5m2×296


N e i m j L − m u l = N 1 ⊕ + N i n t ( i + j ) + + N i n t 8 + N^{\mathcal{L}-mul}_{eimj}=N^{\oplus}_{1}+N^{+}_{int(i+j)}+N^{+}_{int8} NeimjLmul=N1+Nint(i+j)++Nint8+
所以
N f p 16 L − m u l ≈ 256 , N f p 8 L − m u l ≈ 157 N^{\mathcal{L}-mul}_{fp16}\approx256, N^{\mathcal{L}-mul}_{fp8}\approx157 Nfp16Lmul256,Nfp8Lmul157


实验部分省略


介绍一下论文中的related works,介绍神经网络轻量化的典型思路。
1.修剪,减少层间连接数。确定重要的权重,然后重新训练以更新特定任务所选择的权重。\mathcal{L}-Mul针对一般任务设计,不需要针对特定任务再训练。
2.优化张量I/O。在SRAM和HBM(高带宽内存)间移动张量是时间和能量消耗的主要瓶颈。减少transformer中的I/O操作并充分利用HBM可以显著提高AI训练和推理的效率。\mathcal{L}-Mul与该方法正交。
3.取整和量化。标准张量格式为16bit或32bit。然而全尺寸的权重需要相当多的GPU内存。为了提高存储效率,权重存储和计算都可以在较低精度下进行,例如使用16bit、8bit、4bit FP 和Int。比如fp16,bf16,fp8-e4m3,fp8-e5m2,int8,fp4,int4张量来表达模型权重。 与这类方法相比,\mathcal{L}-Mul具有更高的精度和更低的计算量。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值