矩阵相乘的strassen算法_详解矩阵乘法中的Strassen算法

本文详细介绍了Strassen算法,一种用于矩阵乘法的优化算法,旨在减少乘法操作。Strassen算法由Volker Strassen在1969年提出,将矩阵乘法的时间复杂度降低到低于常规的O(n^3)。文章首先解释了矩阵乘法的基本概念,然后深入探讨Strassen算法的原理,包括矩阵的分块、递归计算和合并步骤。最后,文章提到了Strassen算法的代码实现,包括分治思想的应用和效率判断。
摘要由CSDN通过智能技术生成

By LongLuo

机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。

一、矩阵乘法

假设

的矩阵,

的矩阵,那么称

的矩阵

为矩阵

的乘积,记作

,称为矩阵积(matrix product)。

其中矩阵

中的第

行第

列元素可以表示为:

如下图所示:

矩阵乘法

假如在矩阵

和矩阵

中,

,那么完成

需要多少次乘法呢?

对于每一个行向量

,总共有

行;

对于每一个列向量

,总共有

列;

计算它们的内积,总共有

次乘法计算。

综合可以看出,矩阵乘法的算法复杂度是:

二、Strassen算法

那么有没有比

更快的算法呢?

1969年,Volker Strassen提出了第一个算法时间复杂度低于

矩阵乘法算法,算法复杂度为

。从下图可知,Strassen算法只有在对于维数比较大的矩阵 (

) ,性能上才有很大的优势,可以减少很多乘法计算。

Strassen vs. 普通乘法

Strassen算法证明了矩阵乘法存在时间复杂度低于

的算法的存在,后续学者不断研究发现新的更快的算法,截止目前时间复杂度最低的矩阵乘法算法是Coppersmith-Winograd方法的一种扩展方法,其算法复杂度为

三、Strassen原理详解

假设矩阵

和矩阵

都是

的方矩阵,求

,如下所示:

其中

矩阵 C 可以通过下列公式求出:

从上述公式我们可以得出,计算2个

的矩阵相乘需要2个

的矩阵8次乘法和4次加法。我们使用

表示

矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到下面的递推公式:

其中,

表示8次矩阵乘法,而且相乘的矩阵规模降到了

表示4次矩阵加法的时间复杂度以及合并矩阵

的时间复杂度。

最终可计算得到

可以看出每次递归操作都需要8次矩阵相乘,而这正是瓶颈的来源。相比加法,矩阵乘法是非常慢的,于是我们想到能不能减少矩阵相乘的次数呢?

答案是当然可以!!!

Strassen算法正是从这个角度出发,实现了降低算法复杂度!

Strassen实现步骤

实现步骤可以分为以下4步:

按上述方法将矩阵

分解(花费时间

如下创建10个

的矩阵

(花费时间

递归地计算7个矩阵积

,每个矩阵

都是

的。

注意,上述公式中只有中间一列需要计算。

通过

计算

,花费时间

综合可得如下递归式:

进而求出时间复杂度为:

四、Strassen算法的代码实现

类StrassenMatrixComputor提供了3个API供调用:

API

说明

_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT);

普通矩阵乘法计算

_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法

_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)

我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:

第一步:使用Strassen算法收益判断

在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法。

/*

Compute the memory read / write cost for expand Matrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)

*/

float saveCost = (eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);

if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f)

{

return _generateTrivialMatMul(AT, BT, CT);

}

第二步:分块

将矩阵

3个矩阵都分成4块:

auto aStride = AT->stride(0);

auto a11 = AT->host() + 0 * aUnit * eSub + 0 * aStride * lSub;

auto a12 = AT->host() + 0 * aUnit * eSub + 1 * aStride * lSub;

auto a21 = AT->host() + 1 * aUnit * eSub + 0 * aStride * lSub;

auto a22 = AT->host() + 1 * aUnit * eSub + 1 * aStride * lSub;

auto bStride = BT->stride(0);

auto b11 = BT->host() + 0 * bUnit * lSub + 0 * bStride * hSub;

auto b12 = BT->host() + 0 * bUnit * lSub + 1 * bStride * hSub;

auto b21 = BT->host() + 1 * bUnit * lSub + 0 * bStride * hSub;

auto b22 = BT->host() + 1 * bUnit * lSub + 1 * bStride * hSub;

auto cStride = CT->stride(0); auto c11 = CT->host() + 0 * aUnit * eSub + 0 * cStride * hSub;

auto c12 = CT->host() + 0 * aUnit * eSub + 1 * cStride * hSub;

auto c21 = CT->host() + 1 * aUnit * eSub + 0 * cStride * hSub;

auto c22 = CT->host() + 1 * aUnit * eSub + 1 * cStride * hSub;

第三步:分治和递归

Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:

1. If n = 1 Output A × B

2. Else

3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2

4. P1 Strassen(A11,B12 − B22)

5. P2 Strassen(A11 + A12,B22)

6. P3 Strassen(A21 + A22,B11)

7. P4 Strassen(A22,B21 − B11)

8. P5 Strassen(A11 + A22,B11 + B22)

9. P6 Strassen(A12 − A22,B21 + B22)

10. P7 Strassen(A11 − A21,B11 + B12)

11. C11 P5 + P4 − P2 + P6

12. C12 P1 + P2

13. C21 P3 + P4

14. C22 P1 + P5 − P3 − P7

15. Output C

16. End If

例如其中的一步代码如下所示:

{

// S1=A21+A22, T1=B12-B11, P5=S1T1

auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {

MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);

MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);

};

mFunctions.emplace_back(f);

auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);

if (code != NO_ERROR)

{

return code;

}

}

递归执行,得到最终结果!

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值