本篇文章大部分思路与代码都来自于微信公众号“CPP开发者”中2016年4月11日的文章《矩阵相乘优化算法实现讲解》,基本相当于这篇文章的重点重述。
矩阵是什么以及矩阵乘法是怎么操作的,我想点开这篇文章的人都应该知道了,这里就不再赘述了。
首先回顾一下我们最朴素的算法:
//计算矩阵a乘矩阵b,将结果存入c;p是第一个矩阵的行数,q是第二个矩阵的行数,r是第二个矩阵的列数
void mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r)
{
int i,j,k;
//先对c进行初始化
for(i=0;i<p;i++)
{
for(j=0;j<r;j++)
{
c[i][j] = 0;
}
}
//计算矩阵乘法
for(i=0;i<p;i++)
{
for(j=0;j<r;j++)
{
for(k=0;k<q;k++)
{
c[i][j] += a[i][k] * b[k][j];
}
}
}
}
这个算法就是直接模拟矩阵乘法的定义,时间复杂度是O(n^3),同时也是Ω(n^3)。
接下来介绍优化算法:
这个优化算法的最差时间复杂度也是O(n^3),但是对于矩阵中零比较多的情况会有所改善。
基本思路是遍历其中一个矩阵的所有元素,计算所有结果中用到这个元素的部分。如果这个元素是零,那么就没有必要计算了,略过去。这么说可能不清楚,所以还是还是那个代码吧。
int mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r)
{
int i,j,k;
for(i=0;i<p;i++)
{
for(j=0;j<r;j++)
{
c[i][j] = 0;
}
}
for(i=0;i<p;i++)
{
for(k=0;k<q;k++)
{
if(a[i][k]!=0) //如果该元素是零,就省去以下计算
{
for(j=0;j<r;j++)
{
c[i][j] += a[i][k] * b[k][j];
}
}
}
}
}
比起其他最差时间复杂度有有效降低的算法,这一优化算法更便于实现,而且对于零比较多的矩阵会有很好的效果。