AlphaTensor:发现更快的矩阵乘法

参考文献:

  1. Discovering faster matrix multiplication algorithms with reinforcement learning:https://doi.org/10.1038/s41586-022-05172-4

Tensor

矩阵 A ∈ F n × m A \in \mathbb F^{n \times m} AFn×m B ∈ F m × p B \in \mathbb F^{m \times p} BFm×p,矩阵乘积 ( A , B ) ↦ C : = A B (A,B) \mapsto C:=AB (A,B)C:=AB 是双线性的(bilinear),因此可以表示为 3 3 3维的张量(Tensor):
T n , m , p = [ t a b c ] ∈ { 0 , 1 } n m , m p , n p T_{n,m,p} = [t_{abc}] \in \{0,1\}^{nm,mp,np} Tn,m,p=[tabc]{0,1}nm,mp,np
这个 T n , m , p T_{n,m,p} Tn,m,p可以指定读写位置。先把 A , B , C A,B,C A,B,C展开为一维向量 A ′ , B ′ , C ′ A',B',C' A,B,C,如果 t a b c = 1 t_{abc}=1 tabc=1,那么 C c ′ C'_c Cc的值的加和项中就有 A a ′ ⋅ B b ′ A'_a \cdot B'_b AaBb

在这里插入图片描述

我们把 T n , m , p T_{n,m,p} Tn,m,p做分解,得到 U ∈ Z n m × R , V ∈ Z m p × R , W ∈ Z n p × R U \in \mathbb Z^{nm \times R},V \in \mathbb Z^{mp \times R},W \in \mathbb Z^{np \times R} UZnm×R,VZmp×R,WZnp×R,使得
T n , m , p = ∑ r = 1 R u ( r ) ⊗ v ( r ) ⊗ w ( r ) T_{n,m,p} = \sum_{r=1}^R u^{(r)} \otimes v^{(r)} \otimes w^{(r)} Tn,m,p=r=1Ru(r)v(r)w(r)

其中的 ⊗ \otimes 是列向量外积(outer product)或者叫张量积(tensor product)

可以写出根据 T n , m , p T_{n,m,p} Tn,m,p的分解 U , V , W U,V,W U,V,W控制的矩阵乘法:

在这里插入图片描述

如果分解后,列数 R R R越小,那么运算时的小矩阵乘法的数量就不大于 R R R。 著名的 Strassen’s algorithm,它就是找到了张量 T 2 , 2 , 2 T_{2,2,2} T2,2,2 R = 7 R=7 R=7的分解。

AlphaTensor

这篇文章利用 Neural network 搭建了一个模型 AlphaTensor,发现了一系列的对张量 T n , m , p T_{n,m,p} Tn,m,p的更优分解:

在这里插入图片描述

具体的 U , V , W U,V,W U,V,W分解数据在文章附录,诸位可以自己去找。

另外,对于特殊的矩阵也有更优的算法。比如反对称阵(skew-symmetric matrix)和向量的乘法,仅需要 ( n − 1 ) ( n + 2 ) / 2 (n-1)(n+2)/2 (n1)(n+2)/2个乘法运算:

在这里插入图片描述

上述的张量分解,还可以被应用于任意的双线性运算上,比如:循环卷积(多项式乘法,DFT),AlphaTensor 找到了对 n = 2 , 4 , 8 n=2,4,8 n=2,4,8的加速。

  • 6
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值