Flash Attention Softmax: Notes on Block-wise Computation

In a Transformer, softmax over a matrix $M \in \mathbb{R}^{2n \times 2n}$ is computed row by row.
Suppose one row of $M$ is $X = [x_1, x_2, x_3, \dots, x_{2n}]$.

Classic Softmax

$$\mathrm{softmax}([x_1, x_2, \dots, x_{2n}]) = \left\{ \frac{e^{x_i}}{\sum_{j=1}^{2n} e^{x_j}} \right\}_{i=1}^{2n}$$
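A minimal pure-Python sketch of the classic formula, applied to each row of a small matrix (the function name `softmax` and the sample matrix `M` are illustrative). Because every $x_i$ is exponentiated as-is, large scores overflow before the division happens. Python floats are IEEE double precision, so `math.exp` raises `OverflowError` once its argument exceeds roughly 709.78.

```python
import math

def softmax(xs):
    """Classic softmax: e^{x_i} / sum_j e^{x_j}, with no shifting."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Softmax over a matrix is applied independently to each row.
M = [[1.0, 2.0, 3.0, 4.0],
     [0.5, 0.5, 0.5, 0.5]]
probs = [softmax(row) for row in M]
print([round(p, 4) for p in probs[1]])  # [0.25, 0.25, 0.25, 0.25]

# Large scores overflow in math.exp before any normalization.
try:
    softmax([1000.0, 1001.0])
except OverflowError as err:
    print("overflow:", err)
```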

Safe Softmax (safe_softmax)

To prevent overflow when computing the exponentials, the exponents are shifted by the row maximum. This matters especially in half precision: the largest finite float16 value is 65504, so $e^{x}$ already overflows for $x > \ln 65504 \approx 11.09$. Let
$$m = \max([x_1, x_2, \dots, x_{2n}])$$
Then
$$\mathrm{safe\_softmax}([x_1, x_2, \dots, x_{2n}]) = \left\{ \frac{e^{x_i}/e^{m}}{\sum_{j=1}^{2n} e^{x_j}/e^{m}} \right\}_{i=1}^{2n} = \left\{ \frac{e^{x_i-m}}{\sum_{j=1}^{2n} e^{x_j-m}} \right\}_{i=1}^{2n}$$
Dividing numerator and denominator by the same factor $e^{m}$ leaves the result unchanged, while every exponent $x_i - m$ is now at most 0, so no term exceeds 1.
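A minimal sketch of safe_softmax in plain Python (the function name is illustrative). Subtracting $m$ keeps every exponent at or below 0, so inputs such as `[1000.0, 1001.0]`, which would overflow `math.exp` without the shift, are handled without error.

```python
import math

def safe_softmax(xs):
    """Softmax with the row maximum m subtracted before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # every exponent x - m <= 0, so each term is at most 1
    total = sum(exps)                     # >= 1, because the maximum contributes e^0 = 1
    return [e / total for e in exps]

print(safe_softmax([1000.0, 1001.0]))  # ≈ [0.2689, 0.7311]
```

The shift costs one extra pass over the row (to find $m$). That extra pass is exactly what the block-wise scheme below has to work around.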

Flash Attention Softmax

First, break the safe_softmax computation into steps.

For $X = [x_1, x_2, x_3, \dots, x_{2n}]$:

  1. $\max(X) = m$
  2. $\mathrm{fun}(X) = [e^{x_1-m}, e^{x_2-m}, \dots, e^{x_{2n}-m}]$
  3. $s(X) = \mathrm{sum}(\mathrm{fun}(X))$
  4. $\mathrm{safe\_softmax}(X) = \frac{\mathrm{fun}(X)}{s(X)}$
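The four steps map one-to-one onto small helpers. This is a sketch with illustrative names (`row_max`, `fun`, `s`). One assumption differs from the notation above: `fun` takes the maximum `m` as an explicit argument, rather than computing it from its input, so the same helper can later be reused with a block's local maximum.

```python
import math

# Step 1: m = max(X)
def row_max(xs):
    return max(xs)

# Step 2: fun(X) = [e^{x_1 - m}, ..., e^{x_2n - m}]
# (m is passed in explicitly; in the text it is always max(X))
def fun(xs, m):
    return [math.exp(x - m) for x in xs]

# Step 3: s(X) = sum(fun(X))
def s(f):
    return sum(f)

# Step 4: safe_softmax(X) = fun(X) / s(X)
def safe_softmax(xs):
    m = row_max(xs)
    f = fun(xs, m)
    total = s(f)
    return [v / total for v in f]

print(safe_softmax([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
```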

Block-wise Computation

Split $X = [x_1, x_2, x_3, \dots, x_{2n}]$ into $X_1 = [x_1, x_2, x_3, \dots, x_n]$ and $X_2 = [x_{n+1}, x_{n+2}, x_{n+3}, \dots, x_{2n}]$.

Process $X_1$ and $X_2$ separately:

  1. Compute the maximum of each block: $m_1 = \max(X_1)$, $m_2 = \max(X_2)$.
  2. Compute $f_1 = \mathrm{fun}(X_1)$ and $f_2 = \mathrm{fun}(X_2)$, each relative to its own block maximum ($m_1$ and $m_2$ respectively).
  3. Compare the two block maxima: $m = \max(m_1, m_2)$.
  4. Rescale both partial results to the common maximum $m$: $\mathrm{fun}(X) = [e^{m_1-m} f_1, e^{m_2-m} f_2]$. This is valid because $e^{m_k-m} \cdot e^{x_i-m_k} = e^{x_i-m}$ for every element $x_i$ of block $k$.
  5. Sum: $s(X) = \mathrm{sum}(\mathrm{fun}(X))$.
  6. $\mathrm{safe\_softmax}(X) = \frac{\mathrm{fun}(X)}{s(X)}$.
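The six block-wise steps can be sketched directly for two blocks and checked against an unsplit safe_softmax. The names `blockwise_safe_softmax` and `safe_softmax` are illustrative, and the blocks are assumed to be non-empty lists of floats. The usage lines pick a case where $X_1$ holds the larger maximum, so that it is $f_2$ that gets scaled down in step 4.

```python
import math

def fun(xs, m):
    """e^{x_i - m} for each element, relative to a given maximum m."""
    return [math.exp(x - m) for x in xs]

def blockwise_safe_softmax(x1, x2):
    m1, m2 = max(x1), max(x2)             # step 1: local maxima
    f1, f2 = fun(x1, m1), fun(x2, m2)     # step 2: exponentials relative to each local max
    m = max(m1, m2)                       # step 3: global maximum
    # step 4: rescale, using e^{m_k - m} * e^{x_i - m_k} = e^{x_i - m}
    f = [math.exp(m1 - m) * v for v in f1] + [math.exp(m2 - m) * v for v in f2]
    total = sum(f)                        # step 5: normalizer s(X)
    return [v / total for v in f]         # step 6: normalize

def safe_softmax(xs):
    """Reference: safe softmax over the whole, unsplit row."""
    f = fun(xs, max(xs))
    total = sum(f)
    return [v / total for v in f]

x1, x2 = [9.0, 1.0], [2.0, 3.0]  # X_1 holds the maximum, so f_2 is rescaled
a = blockwise_safe_softmax(x1, x2)
b = safe_softmax(x1 + x2)
print(max(abs(p - q) for p, q in zip(a, b)))  # expected at floating-point rounding level
```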
A simple example, for $X = [1, 2, 3, 4, 5, 6]$:
  • $X_1 = [1, 2, 3]$, $X_2 = [4, 5, 6]$
  • $m_1 = \max([1, 2, 3]) = 3$, $m_2 = \max([4, 5, 6]) = 6$
  • $f_1 = [e^{1-m_1}, e^{2-m_1}, e^{3-m_1}] = [e^{-2}, e^{-1}, e^{0}]$; $f_2 = [e^{4-m_2}, e^{5-m_2}, e^{6-m_2}] = [e^{-2}, e^{-1}, e^{0}]$
  • $m = \max(m_1, m_2) = 6$
  • $\mathrm{fun}(X) = [e^{m_1-m} f_1, e^{m_2-m} f_2] = [e^{-3} f_1, e^{0} f_2] = [e^{-5}, e^{-4}, e^{-3}, e^{-2}, e^{-1}, e^{0}]$
  • $s(X) = \mathrm{sum}(\mathrm{fun}(X)) = e^{-5} + e^{-4} + e^{-3} + e^{-2} + e^{-1} + e^{0}$
  • $\mathrm{safe\_softmax}(X) = \frac{\mathrm{fun}(X)}{s(X)}$. This matches safe_softmax computed directly on the whole row, where $m = 6$ yields the same $\mathrm{fun}(X) = [e^{-5}, \dots, e^{0}]$.
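The two-block procedure extends to any number of blocks by carrying a running maximum and a running sum from block to block. When a new block raises the maximum from $m_{\text{old}}$ to $m_{\text{new}}$, the sum accumulated so far is multiplied by $e^{m_{\text{old}} - m_{\text{new}}}$, which is the same correction as step 4. This is a sketch under simplifying assumptions:

* `online_softmax` and `block_size` are hypothetical names.
* It processes one row of plain Python floats.
* It makes a second pass to normalize. The full FlashAttention kernel differs here: it applies the same rescaling to a running output accumulator for $PV$ instead of producing probabilities.

The usage lines compare the result with the closed form from the worked example above for several block sizes.

```python
import math

def online_softmax(xs, block_size):
    """Softmax of one row, streamed block by block (block_size >= 1).

    Pass 1 keeps only a running maximum m and a running sum s of e^{x_i - m};
    pass 2 normalizes every element with the final (m, s).
    """
    m = -math.inf  # running maximum over the blocks seen so far
    s = 0.0        # running sum of e^{x_i - m} over the blocks seen so far
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        m_blk = max(block)
        s_blk = sum(math.exp(x - m_blk) for x in block)
        m_new = max(m, m_blk)
        # Re-express both partial sums relative to the new maximum
        # (math.exp(-inf) == 0.0 handles the very first block).
        s = math.exp(m - m_new) * s + math.exp(m_blk - m_new) * s_blk
        m = m_new
    return [math.exp(x - m) / s for x in xs]

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
# Closed form from the worked example: e^{x_i - 6} / (e^{-5} + ... + e^0)
total = sum(math.exp(k) for k in range(-5, 1))
expected = [math.exp(xi - 6.0) / total for xi in x]
for bs in (1, 2, 3, 6):
    err = max(abs(g - e) for g, e in zip(online_softmax(x, bs), expected))
    print(bs, err)  # err is expected to be at floating-point rounding level
```

With `block_size=3` the loop performs exactly the two-block computation from the example. Smaller blocks just apply the same rescaling more often.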