Flash Attention: A Walkthrough of Blockwise Softmax Computation
In a Transformer, for a matrix $M \in R^{(2n,2n)}$, softmax is computed row by row.
Suppose some row of $M$ is $X=[x_1,x_2,x_3,...,x_{2n}]$.
Classic Softmax
$$softmax([x_1,x_2,...,x_{2n}]) = \left\{ \frac{e^{x_i}}{\sum_{j=1}^{2n} e^{x_j}} \right\}_{i=1}^{2n}$$
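As a minimal sketch (assuming NumPy), the classic formula can be implemented directly, which also exposes why it is unsafe in low precision:

```python
import numpy as np

def naive_softmax(x):
    """Classic softmax: e^{x_i} / sum_j e^{x_j}, computed directly."""
    e = np.exp(x)
    return e / e.sum()

# Works for moderate inputs:
print(naive_softmax(np.array([1.0, 2.0, 3.0])))

# In float16 the exponential overflows quickly: e^12 > 65504
# (the float16 maximum), so the sum becomes inf and the result
# degenerates to NaN.
print(naive_softmax(np.array([12.0, 12.0], dtype=np.float16)))
```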
Half-Precision Softmax (safe_softmax)
To prevent overflow when computing the exponentials, the exponents are shifted by the row maximum:
$$m = max([x_1, x_2, ..., x_{2n}])$$
$$safe\_softmax([x_1,x_2,...,x_{2n}]) = \left\{ \frac{e^{x_i}/e^m}{\sum_{j=1}^{2n} \left( e^{x_j}/e^m \right)} \right\}_{i=1}^{2n} = \left\{ \frac{e^{x_i-m}}{\sum_{j=1}^{2n} e^{x_j-m}} \right\}_{i=1}^{2n}$$
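A sketch of safe_softmax under the same assumptions (NumPy): subtracting the row max keeps every exponent at or below zero, so $e^{x_i-m} \le 1$ and nothing overflows.

```python
import numpy as np

def safe_softmax(x):
    """Numerically safe softmax: shift by the max before exponentiating."""
    m = np.max(x)
    e = np.exp(x - m)  # every exponent is <= 0, so e is in (0, 1]
    return e / e.sum()

# The half-precision input that breaks the naive formula is now fine:
print(safe_softmax(np.array([12.0, 12.0], dtype=np.float16)))  # [0.5 0.5]
```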
Flash Attention Softmax
First, decompose the safe_softmax computation into steps. For $X=[x_1,x_2,x_3,...,x_{2n}]$:
- $max(X) = m$
- $fun(X) = [e^{x_1-m}, e^{x_2-m}, ..., e^{x_{2n}-m}]$
- $s(X) = sum(fun(X))$
- $safe\_softmax(X) = \frac{fun(X)}{s(X)}$
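The four steps above map one-to-one onto a short NumPy sketch (the names `fun` and `s` follow the notation in the text):

```python
import numpy as np

def fun(x, m):
    """Step 2: elementwise e^{x_i - m} for a given shift m."""
    return np.exp(x - m)

def safe_softmax_decomposed(x):
    m = np.max(x)   # step 1: max(X) = m
    f = fun(x, m)   # step 2: fun(X)
    s = f.sum()     # step 3: s(X) = sum(fun(X))
    return f / s    # step 4: safe_softmax(X) = fun(X) / s(X)
```

Exposing `fun` with an explicit shift argument is what makes the blockwise computation in the next section possible: each block can be exponentiated against its own local maximum and rescaled later.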
Blockwise Computation
Split $X=[x_1,x_2,x_3,...,x_{2n}]$ into $X_1=[x_1,x_2,x_3,...,x_n]$ and $X_2=[x_{n+1},x_{n+2},x_{n+3},...,x_{2n}]$.
Compute on $X_1$ and $X_2$ separately:
- Compute each block's maximum: $m_1 = max(X_1)$, $m_2 = max(X_2)$
- $f_1 = fun(X_1)$, $f_2 = fun(X_2)$, each block using its own local maximum
- Take the larger of the two block maxima: $m = max(m_1, m_2)$
- Use the global maximum $m$ to rescale the two blocks' results: $fun(X) = [e^{m_1-m} f_1, e^{m_2-m} f_2]$
- Sum: $s(X) = sum(fun(X))$
- $safe\_softmax(X) = \frac{fun(X)}{s(X)}$
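The merge rule above can be sketched as follows (a NumPy illustration of the softmax part only, not the actual Flash Attention kernel, which also streams the value matrix):

```python
import numpy as np

def blockwise_safe_softmax(x1, x2):
    """safe_softmax over the concatenation [x1, x2], computed blockwise."""
    # Local pass: each block only needs its own max and exponentials.
    m1, m2 = np.max(x1), np.max(x2)
    f1, f2 = np.exp(x1 - m1), np.exp(x2 - m2)
    # Merge pass: rescale each block from its local max to the global max.
    m = max(m1, m2)
    fun = np.concatenate([np.exp(m1 - m) * f1, np.exp(m2 - m) * f2])
    return fun / fun.sum()
```

Because $e^{x_i - m_k} \cdot e^{m_k - m} = e^{x_i - m}$, the rescaled blocks are term-by-term the same quantities as in the full-row safe_softmax.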
A simple example: for $X=[1,2,3,4,5,6]$:
- $X_1=[1,2,3]$, $X_2=[4,5,6]$
- $m_1 = max([1,2,3]) = 3$, $m_2 = max([4,5,6]) = 6$
- $f_1 = [e^{1-m_1}, e^{2-m_1}, e^{3-m_1}] = [e^{-2}, e^{-1}, e^0]$, $f_2 = [e^{4-m_2}, e^{5-m_2}, e^{6-m_2}] = [e^{-2}, e^{-1}, e^0]$
- $m = max(m_1, m_2) = 6$
- $fun(X) = [e^{m_1-m} f_1, e^{m_2-m} f_2] = [e^{-3} f_1, e^0 f_2] = [e^{-5}, e^{-4}, e^{-3}, e^{-2}, e^{-1}, e^0]$
- $s(X) = sum(fun(X)) = e^{-5} + e^{-4} + e^{-3} + e^{-2} + e^{-1} + e^0$
- $safe\_softmax(X) = \frac{fun(X)}{s(X)}$; the result of this step agrees with safe_softmax computed directly on the full row.
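The worked example can be checked numerically (a NumPy sketch):

```python
import numpy as np

# Recompute the example X = [1, 2, 3, 4, 5, 6] blockwise.
X = np.arange(1.0, 7.0)
X1, X2 = X[:3], X[3:]

m1, m2 = X1.max(), X2.max()                 # 3.0 and 6.0
f1, f2 = np.exp(X1 - m1), np.exp(X2 - m2)   # both equal [e^-2, e^-1, e^0]
m = max(m1, m2)                             # 6.0
fun = np.concatenate([np.exp(m1 - m) * f1, np.exp(m2 - m) * f2])
blockwise = fun / fun.sum()

# Direct safe_softmax on the full row for comparison.
direct = np.exp(X - X.max())
direct = direct / direct.sum()

print(np.allclose(blockwise, direct))  # True
```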