快速矩阵乘法的研究
最近的工作主要在于深度学习框架的性能优化。深度学习框架在工程的优化(内存池、SIMD、汇编、GPU、DSP等等)做到接近极限之后,突破点便集中于算法。
深度学习的性能瓶颈主要在于卷积,卷积的运算方法主要是通过 Im2Col / Winograd / FFT 转化为矩阵乘,完成矩阵乘法之后,再转化为目标结果。
深度学习框架的输入是算法工程产出的网络模型,而目前网络模型都渐渐地转变为 mobilenet 那样 1x1 convolution + depthwise 的形式,在精度几乎无损的情况下,既减少了计算量,又减少了模型体积。而这类网络模型,都以 1x1 卷积为主要耗时点。
对 1x1 卷积而言,其本身就是一个矩阵乘法,FFT / Winograd 等卷积算法已经失去价值,因此研读了一些矩阵乘法相关的论文,整理如下。
传统矩阵乘算法
定义
在 1968 年之前,矩阵乘算法只有按定义实现的传统算法,:
设:
A
=
(
a
11
a
12
.
.
.
a
21
a
22
.
.
.
.
.
.
.
.
.
.
.
.
a
n
1
a
n
2
.
.
.
)
B
=
(
b
11
b
12
.
.
.
b
21
b
22
.
.
.
.
.
.
.
.
.
.
.
.
b
n
1
b
n
2
.
.
.
)
A=\begin{pmatrix} a_{11} &a_{12} &... \\ a_{21} &a_{22} &... \\ ... & ... & ... \\ a_{n1} & a_{n2} & ... \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} &... \\ b_{21} &b_{22} &... \\ ... & ... & ... \\ b_{n1} & b_{n2} & ... \\ \end{pmatrix}
A=⎝⎜⎜⎛a11a21...an1a12a22...an2............⎠⎟⎟⎞B=⎝⎜⎜⎛b11b21...bn1b12b22...bn2............⎠⎟⎟⎞
AB 为其乘积,则:
[
A
B
]
p
q
=
∑
i
=
1
n
a
p
i
b
i
q
[AB]_{pq} = \sum_{i=1}^{n}a_{pi}b_{iq}
[AB]pq=i=1∑napibiq
很明显,它是一个 n 3 n^3 n3复杂度的算法,需要 n 3 n^3 n3 次乘法和 n 3 − n 2 n^3-n^2 n3−n2次加法。
矩阵乘表示
设 C = A B C = AB C=AB,A 为 e ∗ l e*l e∗l的矩阵,B 为 l ∗ h l*h l∗h的矩阵,则称这个矩阵乘是一个 [ e , l , h ] [e, l, h] [e,l,h] 的矩阵乘。
快速矩阵乘法的初步探索
Winograd 算法
请注意,这个不是我们通常所说的卷积优化算法,只是同一个人(Winograd大神)在 1968 年提出一种减少乘法数的矩阵乘算法。
其思路是通过两次
n
2
n^2
n2 的乘法预处理,将规模大的矩阵乘法减少一半,但相应的加法增加一半。为了说明简单,这里假定
n
n
n为偶数。
θ
p
=
∑
j
=
1
⌊
n
/
2
⌋
(
a
p
,
2
j
−
1
a
p
,
2
j
)
γ
q
=
∑
j
=
1
⌊
n
/
2
⌋
(
b
2
j
−
1
,
q
b
2
j
,
q
)
[
A
B
]
p
q
=
∑
j
=
1
⌊
n
/
2
⌋
(
a
p
,
2
j
−
1
+
b
2
j
,
q
)
(
a
p
,
2
j
+
b
2
j
−
1
,
q
)
−
θ
p
−
γ
q
\theta_p = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1} a_{p, 2j}) \\\gamma_q = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(b_{2j-1, q}b_{2j, q}) \\ [AB]_{pq} = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1}+b_{2j, q})(a_{p, 2j}+b_{2j-1, q}) - \theta_p - \gamma_q
θp=j=1∑⌊n/2⌋(ap,2j−1ap,2j)γq=j=1∑⌊n/2⌋(b2j−1,qb2j,q)[AB]pq=j=1∑⌊n/2⌋(ap,2j−1+b2j,q)(ap,2j+b2j−1,q)−θp−γq
这个算法没有降低矩阵乘法的阶(还是 n 3 n^3 n3),只是以廉价计算(加法)替代昂贵运算(乘法),需要根据具体的硬件去判断是否可应用。ARM 架构的 CPU,对量化矩阵乘有帮助,但对浮点矩阵乘没有用。
Strassen 矩阵乘算法
Strassen 矩阵乘的思路是通过加减变换,将一个
[
2
,
2
,
2
]
[2, 2, 2]
[2,2,2]的矩阵乘法所用的乘法数由8降到7,并且递归使用,降低矩阵乘法的阶数:
n
3
n^3
n3变成
n
2.81
n^{2.81}
n2.81
A
=
(
a
11
a
12
a
21
a
22
)
B
=
(
b
11
b
12
b
21
b
22
)
A
B
=
(
c
11
c
12
c
21
c
22
)
A=\begin{pmatrix} a_{11} &a_{12} \\ a_{21} &a_{22} \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} \\ b_{21} &b_{22} \\ \end{pmatrix} AB=\begin{pmatrix} c_{11} &c_{12} \\ c_{21} &c_{22} \\ \end{pmatrix}
A=(a11a21a12a22)B=(b11b21b12b22)AB=(c11c21c12c22)
v 1 = ( a 11 + a 22 ) ( b 11 + b 22 ) v 2 = ( a 21 + a 22 ) ( b 11 ) v 3 = ( a 11 ) ( b 12 − b 22 ) v 4 = ( a 22 ) ( b 21 − b 11 ) v 5 = ( a 11 + a 12 ) ( b 22 ) v 6 = ( a 21 − a 11 ) ( b 11 + b 12 ) v 7 = ( a 12 − a 22 ) ( b 21 + b 22 ) v_1 = (a_{11}+a_{22})(b_{11}+b_{22})\\ v_2 = (a_{21}+a_{22})(b_{11})\\v_3 = (a_{11})(b_{12}-b_{22})\\v_4 = (a_{22})(b_{21}-b_{11})\\v_5 = (a_{11}+a_{12})(b_{22})\\v_6 = (a_{21}-a_{11})(b_{11}+b_{12})\\v_7 = (a_{12}-a_{22})(b_{21}+b_{22}) v1=(a11+a22)(b11+b22)v2=(a21+a22)(b11)v3=(a11)(b12−b22)v4=(a22)(b21−b11)v5=(a11+a12)(b22)v6=(a21−a11)(b11+b12)v7=(a12−a22)(b21+b22)
c 11 = v 1 + v 4 − v 5 + v 7 c 21 = v 2 + v 4 c 12 = v 3 + v 5 c 22 = v 1 + v 3 − v 2 + v 6 c_{11} = v_1+v_4-v_5+v_7\\c_{21} = v_2+v_4\\c_{12} = v_3+v_5\\c_{22} = v_1+v_3-v_2+v_6 c11=v1+v4−v5+v7c21=v2+v4c12=v3+v5c22=v1+v3−v2+v6
请注意,其中每个元素( a 11 , b 12 , c 22 a_{11}, b_{12}, c_{22} a11,b12,c22等等)不限于实数,可以是一个矩阵。因为矩阵乘法满足分配率与结合率。这样算法就有了脱离硬件的普适价值,因为矩阵加减的复杂度( n 2 n^2 n2)远低于矩阵乘( n 3 n^3 n3)
Winograd 在 Strassen 的基础上对它的算法进行了改进,减少了加减数(18->15),这个也成为最常用的 Strassen 矩阵乘法应用。
三线性表示
为了方便矩阵乘算法的研究,人们提出一种表示矩阵乘算法的形式,叫“Trilinear-form”,即三线性形式。
我们先以 Strassen 算法为例,它的三线性形式是:
∑
i
=
1
2
∑
j
=
1
2
∑
k
=
1
2
a
i
j
b
j
k
c
i
k
=
(
a
11
)
(
b
12
−
b
22
)
(
c
12
+
c
22
)
+
(
a
11
+
a
12
)
(
b
22
)
(
−
c
11
+
c
12
)
+
(
a
21
+
a
22
)
(
b
11
)
(
c
21
−
c
22
)
+
(
a
22
)
(
b
21
+
b
11
)
(
c
11
+
c
21
)
+
(
a
11
+
a
22
)
(
b
11
+
b
22
)
(
c
11
+
c
22
)
+
(
a
12
−
a
22
)
(
b
21
+
b
22
)
(
c
11
)
+
(
a
11
−
a
21
)
(
b
11
+
b
12
)
(
−
c
22
)
\sum_{i=1}^2\sum_{j=1}^2\sum_{k=1}^2 a_{ij}b_{jk}c_{ik} = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})
∑i=12∑j=12∑k=12aijbjkcik=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)
怎么看这个公式呢,它其实是按 T r a c e ( A B C ) = A B Trace(ABC) = AB Trace(ABC)=AB 的原理去表示的。两个矩阵的乘积,等效于三个矩阵乘积的迹。在上面公式中,如果我们要算出 c 11 c_{11} c11 的解法,就将 c 11 c_{11} c11 设成 1,其他的 c 值, c 12 , c 21 , c 22 c_{12}, c_{21}, c_{22} c12,c21,c22 全设成 0 ,然后将对应的项相加即可。
这个算式总共有7项,这个 7 我们称之为 Rank (阶)
APA——矩阵乘算法的突破
APA,即 Any Precision Algorithm,是把矩阵乘法阶数继续往下降的重要思想,基本思路是先给出近似的矩阵乘法表达式,然后在多阶张量积之后转换为准确的矩阵乘法。
张量积
我们来看 Strassen 矩阵乘法的表达式:
λ
=
(
a
11
)
(
b
12
−
b
22
)
(
c
12
+
c
22
)
+
(
a
11
+
a
12
)
(
b
22
)
(
−
c
11
+
c
12
)
+
(
a
21
+
a
22
)
(
b
11
)
(
c
21
−
c
22
)
+
(
a
22
)
(
b
21
+
b
11
)
(
c
11
+
c
21
)
+
(
a
11
+
a
22
)
(
b
11
+
b
22
)
(
c
11
+
c
22
)
+
(
a
12
−
a
22
)
(
b
21
+
b
22
)
(
c
11
)
+
(
a
11
−
a
21
)
(
b
11
+
b
12
)
(
−
c
22
)
\lambda = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})
λ=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)
对其平方:
λ
2
=
(
(
a
11
)
(
b
12
−
b
22
)
(
c
12
+
c
22
)
+
(
a
11
+
a
12
)
(
b
22
)
(
−
c
11
+
c
12
)
+
(
a
21
+
a
22
)
(
b
11
)
(
c
21
−
c
22
)
+
(
a
22
)
(
b
21
+
b
11
)
(
c
11
+
c
21
)
+
(
a
11
+
a
22
)
(
b
11
+
b
22
)
(
c
11
+
c
22
)
+
(
a
12
−
a
22
)
(
b
21
+
b
22
)
(
c
11
)
+
(
a
11
−
a
21
)
(
b
11
+
b
12
)
(
−
c
22
)
)
2
\lambda^2 = ((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}))^2
λ2=((a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22))2
这是个多项式乘法,不难知
λ
2
\lambda^2
λ2 有
7
2
=
49
7^2=49
72=49 项,我们来看其中一项:
(
(
a
11
)
(
b
12
−
b
22
)
(
c
12
+
c
22
)
)
(
(
a
11
+
a
12
)
(
b
22
)
(
−
c
11
+
c
12
)
)
=
(
a
11
a
11
+
a
11
a
12
)
(
b
12
b
22
−
b
22
b
22
)
(
−
c
12
c
11
+
c
12
c
12
−
c
22
c
11
+
c
22
c
12
)
((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}))((a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}))=(a_{11}a_{11}+a_{11}a_{12})(b_{12}b_{22}-b_{22}b_{22})(-c_{12}c_{11}+c_{12}c_{12}-c_{22}c_{11}+c_{22}c_{12})
((a11)(b12−b22)(c12+c22))((a11+a12)(b22)(−c11+c12))=(a11a11+a11a12)(b12b22−b22b22)(−c12c11+c12c12−c22c11+c22c12)
(依然是将a, b, c 分别组合在一起)
a
,
b
,
c
a, b, c
a,b,c间的相乘,如
a
11
a
12
a_{11}a_{12}
a11a12,我们将其替代为直和:
a
1112
a_{1112}
a1112,其含义可以这么理解,在
a
11
a_{11}
a11的区域(左上角)中,再划分为四块,取其
a
12
a_{12}
a12的区域(右上角)。
不难证明,我们通过这个多项式平方后得到的三线性形式,等效于一个
[
4
,
4
,
4
]
[4, 4, 4]
[4,4,4] 的矩阵乘法。
类似地,我们可以对矩阵乘法的三线性形式进行立方,n次方,以及两个不同的三线性形式乘积,这一系列操作可由“张量积”概括。
APA
Any Precision Algorithm(APA),即任意精度算法,通过在算式中引入一个可配置的实数 λ \lambda λ,得到更好的简化效果。
下面的式子近似用21项表示了一个 [ 3 , 3 , 3 ] [3, 3, 3] [3,3,3]的矩阵乘法
F 1 ( λ ) = ( a 11 + λ 2 a 12 ) ( λ 2 b 11 + b 21 ) c 11 + ( a 21 + λ 2 a 22 ) ( λ 2 b 12 + b 22 ) c 22 + ( a 31 + λ 2 a 32 ) ( λ 2 b 13 + b 23 ) c 33 − a 11 ( b 21 + b 31 ) ( c 11 + c 12 + c 13 ) − a 21 ( b 22 + b 32 ) ( c 21 + c 22 + c 23 ) − a 31 ( b 23 + b 33 ) ( c 31 + c 32 + c 33 ) + ( a 11 + λ 2 a 22 ) ( b 21 − λ b 12 ) c 12 + ( a 21 + λ 2 a 12 ) ( b 22 − λ b 11 ) c 21 + ( a 11 + λ 2 a 32 ) ( b 21 − λ b 13 ) c 13 + ( a 31 + λ 2 a 12 ) ( b 23 − λ b 11 ) c 31 + ( a 21 + λ 2 a 32 ) ( b 22 − λ b 13 ) c 23 + ( a 31 + λ 2 a 22 ) ( b 23 − λ b 12 ) c 32 + ( a 11 + λ 2 a 23 ) ( b 31 + λ b 12 ) ( c 12 + λ c 21 ) + ( a 21 + λ 2 a 13 ) ( b 32 + λ b 11 ) ( c 21 + λ c 12 ) + ( a 11 + λ 2 a 33 ) ( b 31 + λ b 13 ) ( c 13 + λ c 31 ) + ( a 31 + λ 2 a 13 ) ( b 33 + λ b 12 ) ( c 31 + λ c 13 ) + ( a 21 + λ 2 a 33 ) ( b 32 + λ b 13 ) ( c 23 + λ c 32 ) + ( a 31 + λ 2 a 23 ) ( b 33 + λ b 12 ) ( c 32 + λ c 23 ) + ( a 11 + λ 2 a 13 ) b 31 ( c 11 − λ c 31 − λ c 21 ) + ( a 21 + λ 2 a 23 ) b 32 ( c 22 − λ c 32 − λ c 12 ) + ( a 31 + λ 2 a 33 ) b 33 ( c 33 − λ c 13 − λ c 23 ) = λ 2 ( T r a c e ( A B C ) + λ G ( λ ) ) F_1(\lambda) = (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11}\\+(a_{21}+\lambda^2a_{22})(\lambda^2b_{12}+b_{22})c_{22}+(a_{31}+\lambda^2a_{32})(\lambda^2b_{13}+b_{23})c_{33}-a_{11}(b_{21}+b_{31})(c_{11}+c_{12}+c_{13})-a_{21}(b_{22}+b_{32})(c_{21}+c_{22}+c_{23})-a_{31}(b_{23}+b_{33})(c_{31}+c_{32}+c_{33})+(a_{11}+\lambda^2a_{22})(b_{21}-\lambda b_{12})c_{12}+(a_{21}+\lambda^2a_{12})(b_{22}-\lambda b_{11})c_{21}+(a_{11}+\lambda^2a_{32})(b_{21}-\lambda b_{13})c_{13}+(a_{31}+\lambda^2a_{12})(b_{23}-\lambda b_{11})c_{31}+(a_{21}+\lambda^2a_{32})(b_{22}-\lambda b_{13})c_{23}+(a_{31}+\lambda^2a_{22})(b_{23}-\lambda b_{12})c_{32}+(a_{11}+\lambda^2a_{23})(b_{31}+\lambda b_{12})(c_{12}+\lambda c_{21})+(a_{21}+\lambda^2a_{13})(b_{32}+\lambda b_{11})(c_{21}+\lambda c_{12})+(a_{11}+\lambda^2a_{33})(b_{31}+\lambda b_{13})(c_{13}+\lambda c_{31})+(a_{31}+\lambda^2a_{13})(b_{33}+\lambda b_{12})(c_{31}+\lambda c_{13})+(a_{21}+\lambda^2a_{33})(b_{32}+\lambda b_{13})(c_{23}+\lambda c_{32})+(a_{31}+\lambda^2a_{23})(b_{33}+\lambda b_{12})(c_{32}+\lambda c_{23})+(a_{11}+\lambda^2a_{13})b_{31}(c_{11}-\lambda c_{31}-\lambda c_{21})+(a_{21}+\lambda^2a_{23})b_{32}(c_{22}-\lambda c_{32}-\lambda c_{12})+(a_{31}+\lambda^2a_{33})b_{33}(c_{33}-\lambda c_{13}-\lambda c_{23}) = \lambda^2 (Trace(ABC)+\lambda G(\lambda)) F1(λ)=(a11+λ2a12)(λ2b11+b21)c11+(a21+λ2a22)(λ2b12+b22)c22+(a31+λ2a32)(λ2b13+b23)c33−a11(b21+b31)(c11+c12+c13)−a21(b22+b32)(c21+c22+c23)−a31(b23+b33)(c31+c32+c33)+(a11+λ2a22)(b21−λb12)c12+(a21+λ2a12)(b22−λb11)c21+(a11+λ2a32)(b21−λb13)c13+(a31+λ2a12)(b23−λb11)c31+(a21+λ2a32)(b22−λb13)c23+(a31+λ2a22)(b23−λb12)c32+(a11+λ2a23)(b31+λb12)(c12+λc21)+(a21+λ2a13)(b32+λb11)(c21+λc12)+(a11+λ2a33)(b31+λb13)(c13+λc31)+(a31+λ2a13)(b33+λb12)(c31+λc13)+(a21+λ2a33)(b32+λb13)(c23+λc32)+(a31+λ2a23)(b33+λb12)(c32+λc23)+(a11+λ2a13)b31(c11−λc31−λc21)+(a21+λ2a23)b32(c22−λc32−λc12)+(a31+λ2a33)b33(c33−λc13−λc23)=λ2(Trace(ABC)+λG(λ))
当 λ \lambda λ趋于无穷小时,其误差也趋于无穷小,因此我们可以设定任意的精度去使用它,这就是 APA 的由来。
对于 APA 算法,多项式的个数我们称之为 Border Rank,上述算式表示了一个 [ 3 , 3 , 3 ] [3, 3, 3] [3,3,3]的矩阵乘法,在 λ 3 \lambda ^3 λ3的基础上分出误差,我们称之为一个降解: [ 3 , 3 , 3 ] ⊴ 3 21 [3, 3, 3] \unlhd_3 21 [3,3,3]⊴321
现在我们来看怎么把上面的 APA 算法变成准确算法。
直观的做法就是把 λ 2 \lambda^2 λ2项取出来,如: ( a 11 + λ 2 a 12 ) ( λ 2 b 11 + b 21 ) c 11 (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11} (a11+λ2a12)(λ2b11+b21)c11,取出 λ 2 a 11 b 11 c 11 + λ 2 a 12 b 21 c 11 \lambda^2a_{11}b_{11}c_{11}+\lambda^2a_{12}b_{21}c_{11} λ2a11b11c11+λ2a12b21c11,代价就是增加了多项式,不难证明,我们最多会增加到 2 ( 2 + 1 ) / 2 = 3 2(2+1)/2=3 2(2+1)/2=3倍的多项式个数。
无疑,这样做肯定亏了, 3 ∗ 21 = 63 > 3 ∗ 3 ∗ 3 = 27 3*21=63 > 3*3*3=27 3∗21=63>3∗3∗3=27,我们需要施个魔法,就是张量积。
对上面APA 算法进行n次张量积之后,我们可以得到 3 n 3^n 3n大小的矩阵乘算法的降解: [ 3 n , 3 n , 3 n ] ⊴ 2 n + 1 2 1 n [3^n, 3^n, 3^n] \unlhd_{2n+1} 21^n [3n,3n,3n]⊴2n+121n
这时候我们再来取,就不一样了,其阶数变成了:
n
(
2
n
+
1
)
2
1
n
n(2n+1)21^n
n(2n+1)21n
很明显,当 n 足够大时,
n
(
2
n
+
1
)
n(2n+1)
n(2n+1) 和指数项相比可忽略,这样我们就得到了更好的准确算法,其阶数为:
3
l
n
(
21
)
/
l
n
(
27
)
≈
2.77
3ln(21)/ln(27)\approx2.77
3ln(21)/ln(27)≈2.77
下篇内容:
1、组合矩阵乘
2、渐近和定理
3、Strassen构造
4、Coppersmith–Winograd 算法