简单的矩阵优化算法

对于矩阵乘法   C =   A × B ,通常的做法是将矩阵进行分块相乘,如下图所示:

 

 

 

从上图可以看出这种分块相乘总共用了 8 次乘法,当然对于子矩阵相乘(如 A0 × B0 ),还可以继续递归使用分块相乘。对于中小矩阵来说,很适合使用这种分块乘法,但是对于大矩阵来说,递归的次数较多,如果能减少每次分块乘法的次数,那么性能将可以得到很好的提高。

Strassen 矩阵乘法就是采用了一个简单的运算技巧,将上面的 8 次矩阵相乘变成了 7 次乘法,看别小看这减少的 1 次乘法,因为每递归 1 次,性能就提高了 1/8 ,比如对于 1024*1024 的矩阵,第 1 次先分解成 7512*512 的矩阵相乘,对于 512*512 的矩阵,又可以继续递归分解成 256*256 的矩阵相乘, ,一直递归下去,假设分解到 64*64 的矩阵大小后就不再递归,那么所花的时间将是分块矩阵乘法的 (7/8) * (7/8) * (7/8) * (7/8) = 0.586 倍,提高了快接近一倍。当然这是理论上的值,因为实际上 strassen 乘法增加了其他运算开销,实际性能会略低一点。

下面就是 Strassen 矩阵乘法的实现方法,

    M1 = (A0 + A3) × (B0 + B3)

     M2 = (A2 + A3) × B0

    M3 = A0 × (B1 - B3)

    M4 = A3 × (B2 - B0)

    M5 = (A0 + A1) × B3

    M6 = (A2 - A0) × (B0 + B1)

    M7 = (A1 - A3) × (B2 + B3)

    C0 = M1 + M4 - M5 + M7

    C1 = M3 + M5

    C2 = M2 + M4

    C3 = M1 - M2 + M3 + M6

在求解 M1,M2,M3,M4,M5,M6,M7 时需要使用 7 次矩阵乘法,其他都是矩阵加法和减法。

下面看看 Strassen 矩阵乘法的串行实现伪代码:

Serial_StrassenMultiply(A, B, C) 

{

    T1 = A0 + A3;

    T2 = B0 + B3;

    StrassenMultiply(T1, T2, M1);

    T1 = A2 + A3;

    StrassenMultiply(T1, B0, M2);

    T1 = (B1 - B3);

    Strassen Multiply (A0, T1, M3);

 

    T1 = B2 - B0;

    StrassenMultiply(A3, T1, M4);

 

     T1 = A0 + A1;

     StrassenMultiply(T1, B3, M5);       

   

    T1 = A2 – A0;

    T2 = B0 + B1;

    StrassenMultiply(T1, T2, M6);

    T1 = A1 – A3;

    T2 = B2 + B3;

    StrassenMultiply(T1, T2, M7);

    C0 = M1 + M4 - M5 + M7

    C1 = M3 + M5

    C2 = M2 + M4

    C3 = M1 - M2 + M3 + M6

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值