快速矩阵乘法的研究
渐近和不等式(Asymptotic sum inequality)
有了APA之后,我们突破了Strassen 算法的阈值,但希望进一步往下降时,总是构造不出来,遇到了瓶颈。这个时候,人们发现了一条新的道路。
同时计算多个矩阵乘法
对于 C = A B C=AB C=AB 这个矩阵乘法,我们已经找不到更好的构造算法,但如果同时计算 C 1 = A 1 B 1 , C 2 = A 2 B 2 C_1=A_1B_1, C_2=A_2B_2 C1=A1B1,C2=A2B2,我们会不会有好方法呢?
答案是有的, 请看下面的构造:
设:
ϕ
=
∑
i
=
1
e
∑
j
=
1
l
a
i
b
j
c
j
,
i
+
∑
i
=
1
e
−
1
∑
j
=
1
l
−
1
X
i
,
j
Y
i
,
j
Z
\phi = \sum_{i=1}^e\sum_{j=1}^la_ib_jc_{j,i}+\sum_{i=1}^{e-1}\sum_{j=1}^{l-1}X_{i,j}Y_{i,j}Z
ϕ=i=1∑ej=1∑laibjcj,i+i=1∑e−1j=1∑l−1Xi,jYi,jZ
这条算式同时计算了两个不相关的矩阵乘法
C
=
A
B
,
Z
=
X
Y
C=AB, Z=XY
C=AB,Z=XY,一个尺寸是
[
e
,
1
,
l
]
[e, 1, l]
[e,1,l],另一个是
[
1
,
(
e
−
1
)
(
l
−
1
)
,
1
]
[1, (e-1)(l-1), 1]
[1,(e−1)(l−1),1],这种计算我们定义为
[
e
,
1
,
l
]
⊕
[
1
,
(
e
−
1
)
(
l
−
1
)
,
1
]
[e, 1, l]\oplus [1, (e-1)(l-1) ,1]
[e,1,l]⊕[1,(e−1)(l−1),1]
不难看出,原始需要的乘法数为
e
l
+
(
e
−
1
)
(
l
−
1
)
el+(e-1)(l-1)
el+(e−1)(l−1)
现在我们来构造快速算法,令:
X
i
,
l
=
0
,
X
e
,
j
=
−
∑
i
=
1
e
−
1
X
i
,
j
,
Y
i
,
l
=
−
∑
j
=
1
l
−
1
Y
i
,
j
X_{i,l}=0, X_{e, j}=-\sum_{i=1}^{e-1}X_{i,j}, Y_{i, l}=-\sum_{j=1}^{l-1}Y_{i,j}
Xi,l=0,Xe,j=−i=1∑e−1Xi,j,Yi,l=−j=1∑l−1Yi,j
则有:
F
′
=
∑
i
=
1
e
∑
j
=
1
l
(
a
i
+
λ
X
i
,
j
)
(
b
j
+
λ
Y
i
,
j
)
(
λ
2
c
j
,
i
+
Z
)
−
∑
i
=
1
e
a
i
∑
j
=
1
l
b
j
Z
=
λ
2
ϕ
+
λ
3
G
(
λ
)
F'=\sum_{i=1}^e\sum_{j=1}^l(a_i+\lambda X_{i,j})(b_j+\lambda Y_{i,j})(\lambda ^2 c_{j, i}+Z)-\sum_{i=1}^ea_i\sum_{j=1}^lb_jZ=\lambda ^2 \phi +\lambda ^3 G(\lambda)
F′=i=1∑ej=1∑l(ai+λXi,j)(bj+λYi,j)(λ2cj,i+Z)−i=1∑eaij=1∑lbjZ=λ2ϕ+λ3G(λ)
这个式子为刚才的联合矩阵乘给出了一个
e
l
+
1
el+1
el+1的 Border Rank,也即:
[
e
,
1
,
l
]
⊕
[
1
,
(
e
−
1
)
(
l
−
1
)
,
1
]
⊴
3
(
e
l
+
1
)
[e, 1, l]\oplus [1, (e-1)(l-1) ,1]\unlhd_3 (el+1)
[e,1,l]⊕[1,(e−1)(l−1),1]⊴3(el+1),若取
e
=
4
,
l
=
4
e=4, l=4
e=4,l=4,这一次展开就减少了 32% 的乘法量,十分诱人。
但这对矩阵乘法的优化有什么作用呢,请看下面的定理
Asymptotic Sum Inequality
定理:设
ω
\omega
ω为矩阵乘法的阶,假设有一个Border Rank 为 r 的算法
ϕ
\phi
ϕ 能够同时计算 s 个矩阵乘法
[
e
1
,
h
1
,
l
1
]
,
[
e
2
,
h
2
,
l
2
]
.
.
.
[
e
s
,
h
s
,
l
s
]
[e_1, h_1, l_1], [e_2, h_2, l_2]...[e_s, h_s, l_s]
[e1,h1,l1],[e2,h2,l2]...[es,hs,ls],也即
⨁
i
=
1
s
[
e
i
,
h
i
,
l
i
]
⊴
q
r
\bigoplus_{i=1}^s [e_i,h_i,l_i] \unlhd_q r
⨁i=1s[ei,hi,li]⊴qr,那么:
∑
i
=
1
s
(
e
i
h
i
l
i
)
ω
/
3
≤
r
\sum_{i=1}^s(e_ih_il_i)^{\omega /3}\le r
i=1∑s(eihili)ω/3≤r
定理的严格证明可自行查阅文献,这里力求简单地说明思路。
首先,我们不难看出,能够以 r 个多项式计算多个矩阵乘的算法,可以把一些项设成零,推导出一些用小于等于r个多项式计算其中一个或部分矩阵乘的算法。
⨁
i
=
1
s
[
e
i
,
h
i
,
l
i
]
⊴
q
r
⟶
[
e
i
,
h
i
,
l
i
]
⊴
q
r
\bigoplus_{i=1}^s [e_i,h_i,l_i] \unlhd_q r \longrightarrow [e_i, h_i, l_i] \unlhd_q r
i=1⨁s[ei,hi,li]⊴qr⟶[ei,hi,li]⊴qr
比如上节的 F ′ = ∑ i = 1 e ∑ j = 1 l ( a i + λ X i , j ) ( b j + λ Y i , j ) ( λ 2 c j , i + Z ) − ∑ i = 1 e a i ∑ j = 1 l b j Z = λ 2 ϕ + λ 3 G ( λ ) F'=\sum_{i=1}^e\sum_{j=1}^l(a_i+\lambda X_{i,j})(b_j+\lambda Y_{i,j})(\lambda ^2 c_{j, i}+Z)-\sum_{i=1}^ea_i\sum_{j=1}^lb_jZ=\lambda ^2 \phi +\lambda ^3 G(\lambda) F′=i=1∑ej=1∑l(ai+λXi,j)(bj+λYi,j)(λ2cj,i+Z)−i=1∑eaij=1∑lbjZ=λ2ϕ+λ3G(λ)
我们在 a i , b j , c j , i a_i, b_j, c_{j,i} ai,bj,cj,i前面都乘个0,便成了 Z = X Y Z=XY Z=XY的算法,类似可得 C = A B C=AB C=AB。如此进行处理,多项式的个数只会减少,不会增加。
然后,我们进行张量积,求其n次幂,方便起见只考虑两个矩阵乘算法。
先平方(n=2),来看看发生了什么。
(
[
e
1
,
h
1
,
l
1
]
⊕
[
e
2
,
h
2
,
l
2
]
)
2
=
[
e
1
2
,
h
1
2
,
l
1
2
]
⊕
[
e
2
2
,
h
2
2
,
l
2
2
]
⊕
[
e
1
e
2
,
h
1
h
2
,
l
1
l
2
]
⊕
[
e
1
e
2
,
h
1
h
2
,
l
1
l
2
]
([e_1, h_1, l_1]\oplus [e_2,h_2,l_2])^2 = [e_1^2,h_1^2,l_1^2]\oplus [e_2^2,h_2^2,l_2^2] \oplus [e_1e_2,h_1h_2,l_1l_2] \oplus [e_1e_2,h_1h_2,l_1l_2]
([e1,h1,l1]⊕[e2,h2,l2])2=[e12,h12,l12]⊕[e22,h22,l22]⊕[e1e2,h1h2,l1l2]⊕[e1e2,h1h2,l1l2]
张量积平方后,我们同时计算了四个独立的矩阵乘法,其中有2个矩阵乘是形状一样的,如果展开较多,比如达到7个,我们能将其拼起来,按上一章所述的 Strassen 算法构造一个单独的矩阵乘。
定理的证明就是在n足够大时,找一个 μ \mu μ,并令 P = ( C n μ ) 1 / ω P=(C_n^{\mu})^{1/\omega} P=(Cnμ)1/ω,然后把这组矩阵乘拼成一个 [ P e 1 μ e 2 n − μ , P h 1 μ h 2 n − μ , P l 1 μ l 2 n − μ ] [Pe_1^{\mu}e_2^{n-\mu}, Ph_1^{\mu}h_2^{n-\mu}, Pl_1^{\mu}l_2^{n-\mu}] [Pe1μe2n−μ,Ph1μh2n−μ,Pl1μl2n−μ]的矩阵乘,这个矩阵乘可以用不大于 ( ( q − 1 ) N + 1 ) 2 r n ((q-1)N+1)^2r^n ((q−1)N+1)2rn的乘法算出来,进一步证明出最终等式。
在定理的证明过程中,我们不难发现这个矩阵乘算法构造是一个迭代过程:
我们先以当前最好的矩阵乘算法,找到一组
n
,
μ
,
P
n, \mu, P
n,μ,P,拼出一个矩阵乘算法,然后再用这个矩阵乘算法重新确定
n
,
μ
,
P
n, \mu, P
n,μ,P,再拼,直到不能下降为止。
通过这个定理,在 e = 4 , l = 4 e=4, l=4 e=4,l=4时,可证明出 ω < 2.5479 \omega < 2.5479 ω<2.5479,大大地前进了。
Strassen 构造
前面我们介绍了如何用多个独立矩阵乘直和去构造矩阵乘法。而 Strassen 通过 Matrix To Scalar 定理,用非独立矩阵乘构造出更好的结果,将矩阵乘的阶数降到了2.48。
Matrix To Scalar
定理:一个
[
N
,
N
,
N
]
[N, N, N]
[N,N,N]的矩阵乘法,可以近似地(APA)计算
3
N
2
/
4
3N^2/4
3N2/4 个独立的实数乘积。
初看这定理会感觉莫名奇妙,本身矩阵乘法包含了
N
3
N^3
N3项,去计算
3
N
2
/
4
3N^2/4
3N2/4个实数不是很亏么。
其实是这样的,矩阵乘法
N
3
N^3
N3个项并不是互相独立的,以
[
2
,
2
,
2
]
[2, 2, 2]
[2,2,2]为例,计算出来的是:
a
11
b
11
c
11
+
a
12
b
21
c
11
+
a
11
b
12
c
21
+
a
12
b
22
c
21
+
a
21
b
11
c
12
+
a
22
b
21
c
12
+
a
21
b
12
c
22
+
a
22
b
22
c
22
a_{11}b_{11}c_{11}+a_{12}b_{21}c_{11}+a_{11}b_{12}c_{21}+a_{12}b_{22}c_{21}+a_{21}b_{11}c_{12}+a_{22}b_{21}c_{12}+a_{21}b_{12}c_{22}+a_{22}b_{22}c_{22}
a11b11c11+a12b21c11+a11b12c21+a12b22c21+a21b11c12+a22b21c12+a21b12c22+a22b22c22
其中
a
11
b
11
c
11
a_{11}b_{11}c_{11}
a11b11c11 和
a
12
b
21
c
11
a_{12}b_{21}c_{11}
a12b21c11 共同含有
c
11
c_{11}
c11,
a
11
b
11
c
11
a_{11}b_{11}c_{11}
a11b11c11 和
a
11
b
12
c
21
a_{11}b_{12}c_{21}
a11b12c21 共同含有
a
11
a_{11}
a11,也即不互相独立。
这个定理就是找到一个构造方法,筛出 3 N 2 / 4 3N^2/4 3N2/4个互相独立的乘数出来。
设定
g
=
[
3
(
N
+
1
)
/
2
]
g=[3(N+1)/2]
g=[3(N+1)/2],然后按如下方式给
[
N
,
N
,
N
]
[N, N, N]
[N,N,N]矩阵乘法中的各个乘数乘以
λ
\lambda
λ:
∑
i
=
1
N
∑
j
=
1
N
∑
k
=
1
N
(
x
i
,
j
λ
i
2
+
2
i
j
)
(
y
j
,
k
λ
j
2
+
2
j
(
k
−
g
)
)
(
z
k
,
i
λ
(
k
−
g
)
2
+
2
(
k
−
g
)
i
)
=
∑
i
+
j
+
k
=
g
x
i
,
j
y
j
,
k
z
k
,
i
+
O
(
λ
)
\sum_{i=1}^N\sum_{j=1}^N\sum_{k=1}^N(x_{i,j}\lambda^{i^2+2ij})(y_{j,k}\lambda^{j^2+2j(k-g)})(z_{k,i}\lambda^{(k-g)^2+2(k-g)i}) = \sum_{i+j+k=g}x_{i,j}y_{j,k}z_{k,i} +O(\lambda)
i=1∑Nj=1∑Nk=1∑N(xi,jλi2+2ij)(yj,kλj2+2j(k−g))(zk,iλ(k−g)2+2(k−g)i)=i+j+k=g∑xi,jyj,kzk,i+O(λ)
如此,构造出的等式右边是那 3 N 2 / 4 3N^2/4 3N2/4个 x , y , z x, y, z x,y,z均不相同的乘数和。
Strassen 构造
用如下式子进行一个初始的APA构造:
∑
i
=
1
q
(
x
0
+
λ
x
i
)
(
y
0
+
λ
y
i
)
(
x
i
λ
−
1
)
+
x
0
y
0
(
−
∑
i
=
1
q
z
i
λ
−
1
)
=
∑
i
=
1
q
(
x
i
y
0
z
i
+
x
0
y
i
z
i
)
\sum_{i=1}^q(x_0+\lambda x_i)(y_0+\lambda y_i)(x_i\lambda ^{-1})+x_0y_0(-\sum_{i=1}^qz_i\lambda^{-1}) = \sum_{i=1}^q(x_iy_0z_i+x_0y_iz_i)
i=1∑q(x0+λxi)(y0+λyi)(xiλ−1)+x0y0(−i=1∑qziλ−1)=i=1∑q(xiy0zi+x0yizi)
上面的算子用 q + 1 q+1 q+1个乘数计算了 2 q 2q 2q个乘数,但右边的式子并非矩阵乘法,也不是两个互相独立的矩阵乘法和。
我们将
x
0
x_0
x0和
{
x
1
,
x
2
,
.
.
.
,
x
q
}
\{x_1, x_2, ..., x_q\}
{x1,x2,...,xq}各视为一个数
X
0
,
X
1
X_0,X_1
X0,X1,类似地
{
y
1
,
y
2
,
.
.
.
,
y
q
}
\{y_1, y_2, ..., y_q\}
{y1,y2,...,yq}和
{
z
1
,
z
2
,
.
.
.
,
z
q
}
\{z_1, z_2, ..., z_q\}
{z1,z2,...,zq}也视为一个数,上面的等式右边就成了一个
[
1
,
2
,
1
]
[1, 2, 1]
[1,2,1]的"矩阵乘法":
X
1
Y
0
Z
1
+
X
0
Y
1
Z
1
X_1Y_0Z_1+X_0Y_1Z_1
X1Y0Z1+X0Y1Z1
我们将 x, y, z 轮换两次,构建额外的两个式子:
X
1
Y
0
Z
1
+
X
1
Y
1
Z
0
,
[
1
,
1
,
2
]
X_1Y_0Z_1 + X_1Y_1Z_0, [1, 1, 2]
X1Y0Z1+X1Y1Z0,[1,1,2]
X
0
Y
1
Z
1
+
X
1
Y
1
Z
0
,
[
2
,
1
,
1
]
X_0Y_1Z_1 + X_1Y_1Z_0, [2, 1, 1]
X0Y1Z1+X1Y1Z0,[2,1,1]
这两个式子和原来那个式子作张量积,我们得到一个
[
2
,
2
,
2
]
[2, 2, 2]
[2,2,2]的"矩阵乘法",这个矩阵乘法需要
(
q
+
1
)
3
(q+1)^3
(q+1)3个多项式:
X
1
,
1
,
0
Y
0
,
0
,
1
Z
1
,
1
,
1
+
X
1
,
1
,
1
Y
0
,
0
,
1
Z
1
,
1
,
0
+
X
1
,
1
,
0
Y
0
,
1
,
1
Z
1
,
0
,
1
+
X
1
,
1
,
1
Y
0
,
1
,
1
Z
1
,
0
,
0
+
X
0
,
1
,
0
Y
1
,
0
,
1
Z
1
,
1
,
1
+
X
0
,
1
,
1
Y
1
,
0
,
1
Z
1
,
1
,
0
+
X
0
,
1
,
0
Y
1
,
1
,
1
Z
1
,
0
,
1
+
X
0
,
1
,
1
Y
1
,
1
,
1
Z
1
,
0
,
0
X_{1, 1, 0}Y_{0,0,1}Z_{1, 1, 1} + X_{1, 1, 1}Y_{0,0,1}Z_{1, 1, 0}+X_{1, 1, 0}Y_{0,1,1}Z_{1, 0, 1} + X_{1, 1, 1}Y_{0,1,1}Z_{1, 0, 0}+X_{0, 1, 0}Y_{1,0,1}Z_{1, 1, 1} + X_{0, 1, 1}Y_{1,0,1}Z_{1, 1, 0}+X_{0,1,0}Y_{1,1,1}Z_{1, 0, 1}+X_{0,1,1}Y_{1,1,1}Z_{1,0,0}
X1,1,0Y0,0,1Z1,1,1+X1,1,1Y0,0,1Z1,1,0+X1,1,0Y0,1,1Z1,0,1+X1,1,1Y0,1,1Z1,0,0+X0,1,0Y1,0,1Z1,1,1+X0,1,1Y1,0,1Z1,1,0+X0,1,0Y1,1,1Z1,0,1+X0,1,1Y1,1,1Z1,0,0
这个“矩阵乘法中”,每个项都是一个 [ e , h , l ] , e ∗ h ∗ l = q 3 [e, h, l], e*h*l=q^3 [e,h,l],e∗h∗l=q3的矩阵乘法。
然后在这个张量积的基础上,将其作N次幂,就是一个
[
2
N
,
2
N
,
2
N
]
[2^N, 2^N, 2^N]
[2N,2N,2N]的"矩阵乘法",对应的需要
(
q
+
3
)
3
N
(q+3)^{3N}
(q+3)3N个多项式,这时候应用 Matrix to Scalar,将其多项式独立出来,就成了
3
/
4
∗
2
2
N
3/4*2^{2N}
3/4∗22N个满足
e
∗
h
∗
l
=
q
3
N
e*h*l=q^{3N}
e∗h∗l=q3N的矩阵乘法和,应用 Asymptotic Sum Inequality,可得:
(
3
/
4
)
2
2
N
q
N
ω
≤
(
q
+
1
)
3
N
(3/4)2^{2N}q^{N\omega} \leq (q+1)^{3N}
(3/4)22NqNω≤(q+1)3N
令N趋于无穷大,可得:
2
2
q
ω
<
=
(
q
+
1
)
3
2^2q^{\omega} <= (q+1)^3
22qω<=(q+1)3
设 q = 5,便得到了
ω
<
=
2.4785
\omega <= 2.4785
ω<=2.4785
Coppersmith 和 Winograd 用另外一种方法,将矩阵乘阶数降到了 2.37,这部分内容比想象中要复杂,因此放到最后一篇介绍。
参考文献
- Matrix Multiplication Via Arithmetic Progressions (DON COPPERSMITH and SIIMUEL WINOGRAD)
- On the Complexity of Matrix Multiplication (Andrew James Stothers)