问题描述
参考答案
算法分析
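由于 $m$ 是奇数,无法反复对半划分,因此不适用 Strassen 算法;而 $2^k$ 是 2 的幂,可以反复对半划分,适用于 Strassen 算法。所以计算 $n\times n$ 阶矩阵($n=m\cdot 2^k$)的乘积时,可以把矩阵看作由 $m\times m$ 个 $2^k\times 2^k$ 子矩阵组成的分块矩阵。先在分块层面用传统算法组织计算,共需 $m^3$ 次子矩阵乘法;再用 Strassen 算法计算每一对 $2^k\times 2^k$ 子矩阵的乘积。并且在计算 m1–m7 时,还可以递归地利用 Strassen 算法计算子矩阵的子矩阵。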
递归函数
分析
如上图所示,调用 Strassen 算法计算 $2^k\times 2^k$ 阶矩阵的乘积时,其 $2^{k-1}\times 2^{k-1}$ 子矩阵同样适用于 Strassen 算法,因此可以递归求解,直到 $1\times 1$(或 $2\times 2$)的基本情形。
算法实现
Strassen乘法递归实现
#include <iostream>
#include <stdlib.h>
#include <math.h>
using namespace std;
// Forward declarations; the definitions appear after main().
// getK: returns the exponent k in n = m * 2^k (m odd). The parameter is the matrix order n,
// even though this declaration names it k.
int getK(int k);
// Strassen: multiplies two 2^k x 2^k matrices and returns a newly allocated product.
int** Strassen(int** left, int** right, int k);
// add / sub: element-wise sum / difference of two n x n matrices, returned as a new matrix.
int** add(int** a1, int** a2, int n);
int** sub(int** a1, int** a2, int n);
// Allocates an n x n matrix with every entry zero. The layout is a row-pointer array
// plus one allocation per row, which is the same layout Strassen() returns.
static int** allocMatrixZeroed(int n) {
    int** mat = (int**)malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        mat[i] = (int*)calloc(n, sizeof(int));
    }
    return mat;
}

// Frees a matrix stored as a row-pointer array plus one allocation per row.
static void freeMatrix(int** mat, int n) {
    for (int i = 0; i < n; i++) {
        free(mat[i]);
    }
    free(mat);
}

// Parses n*n integers from text into dst, in row-major order. The integers may be
// separated by spaces, tabs, newlines and/or commas. Returns false if text holds
// fewer than n*n integers.
static bool parseMatrix(const char* text, int** dst, int n) {
    const char* p = text;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            while (*p == ',' || *p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') {
                ++p;
            }
            char* end = NULL;
            long value = strtol(p, &end, 10);
            if (end == p) {
                return false;  // no digits: malformed or too few entries
            }
            dst[i][j] = (int)value;
            p = end;
        }
    }
    return true;
}

/*
 * Multiplies two n x n integer matrices given on the command line and prints the product.
 *
 * Usage: prog <n> <left> <right>
 *   n      matrix order (positive integer)
 *   left   n*n integers of the left matrix, row-major, separated by spaces and/or commas
 *   right  n*n integers of the right matrix, in the same format
 *
 * The order is factored as n = m * 2^k with m odd. Each matrix is viewed as an m x m grid
 * of 2^k x 2^k blocks. The grid is multiplied with the conventional triple loop, and every
 * block product is computed by Strassen().
 * The product is printed one row per line. Returns -1 on invalid arguments.
 */
int main(int argc, char** argv)
{
    // argv[0] is the program name, so the three real arguments are argv[1..3].
    if (argc != 4)
        return -1;
    int n = atoi(argv[1]);
    if (n <= 0)
        return -1;

    int** left = allocMatrixZeroed(n);
    int** right = allocMatrixZeroed(n);
    if (!parseMatrix(argv[2], left, n) || !parseMatrix(argv[3], right, n)) {
        freeMatrix(left, n);
        freeMatrix(right, n);
        return -1;
    }

    int k = getK(n);
    int bs = 1 << k;   // block order 2^k
    int m = n / bs;    // odd number of blocks per side

    int** result = allocMatrixZeroed(n);
    int** blockA = allocMatrixZeroed(bs);
    int** blockB = allocMatrixZeroed(bs);
    for (int bi = 0; bi < m; bi++) {
        for (int bj = 0; bj < m; bj++) {
            for (int bl = 0; bl < m; bl++) {
                if (bs == 1) {
                    // n is odd: the blocks are scalars, so plain multiplication suffices.
                    result[bi][bj] += left[bi][bl] * right[bl][bj];
                    continue;
                }
                // C(bi,bj) += A(bi,bl) * B(bl,bj), with the block product done by Strassen.
                for (int i = 0; i < bs; i++) {
                    for (int j = 0; j < bs; j++) {
                        blockA[i][j] = left[bi * bs + i][bl * bs + j];
                        blockB[i][j] = right[bl * bs + i][bj * bs + j];
                    }
                }
                int** prod = Strassen(blockA, blockB, k);
                for (int i = 0; i < bs; i++) {
                    for (int j = 0; j < bs; j++) {
                        result[bi * bs + i][bj * bs + j] += prod[i][j];
                    }
                }
                freeMatrix(prod, bs);
            }
        }
    }

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            std::cout << result[i][j] << (j + 1 < n ? ' ' : '\n');
        }
    }

    freeMatrix(blockA, bs);
    freeMatrix(blockB, bs);
    freeMatrix(result, n);
    freeMatrix(left, n);
    freeMatrix(right, n);
    return 0;
}
/*
 * Returns k such that n = m * 2^k with m odd, i.e. the number of factors of 2 in n.
 *
 * @param n  matrix order; negative values are handled by their magnitude
 * @return   the exponent k. n == 0 has no such factorisation, so 0 is returned.
 *           The previous loop never terminated for 0, because 0 % 2 == 0 and 0 / 2 == 0.
 */
int getK(int n) {
    if (n == 0) {
        return 0;
    }
    int k = 0;
    // Test "divisible by 2" rather than "n % 2 == 1": for negative n, n % 2 is -1, not 1.
    while (n % 2 == 0) {
        n /= 2;
        k++;
    }
    return k;
}
/*
 * Element-wise matrix addition.
 *
 * @param a1,a2  n x n operands (arrays of row pointers)
 * @param n      matrix order
 * @return newly malloc'd n x n matrix holding a1 + a2 (row array plus one block per row);
 *         the caller owns it.
 */
int** add(int** a1, int** a2, int n) {
    int** sum = (int**)malloc(n * sizeof(int*));
    for (int row = 0; row < n; ++row) {
        int* out = (int*)malloc(n * sizeof(int));
        const int* lhs = a1[row];
        const int* rhs = a2[row];
        for (int col = 0; col < n; ++col)
            out[col] = lhs[col] + rhs[col];
        sum[row] = out;
    }
    return sum;
}
/*
 * Element-wise matrix subtraction.
 *
 * @param a1  n x n minuend (array of row pointers)
 * @param a2  n x n subtrahend
 * @param n   matrix order
 * @return newly malloc'd n x n matrix holding a1 - a2 (row array plus one block per row);
 *         the caller owns it.
 */
int** sub(int** a1, int** a2, int n) {
    int** diff = (int**)malloc(n * sizeof(int*));
    int r = 0;
    while (r < n) {
        diff[r] = (int*)malloc(n * sizeof(int));
        for (int c = 0; c < n; c++) {
            diff[r][c] = a1[r][c] - a2[r][c];
        }
        ++r;
    }
    return diff;
}
// Strassen's algorithm: recursive multiplication of 2^k x 2^k matrices.

// Allocates an n x n matrix as a row-pointer array plus one malloc per row.
// Entries are left uninitialized.
static int** strassenAlloc(int n) {
    int** mat = (int**)malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        mat[i] = (int*)malloc(n * sizeof(int));
    }
    return mat;
}

// Frees a matrix produced by strassenAlloc (row blocks first, then the row array).
static void strassenFree(int** mat, int n) {
    for (int i = 0; i < n; i++) {
        free(mat[i]);
    }
    free(mat);
}

// Copies the h x h block of src whose top-left corner is (rowOff, colOff) into a new matrix.
static int** strassenQuadrant(int** src, int h, int rowOff, int colOff) {
    int** q = strassenAlloc(h);
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            q[i][j] = src[i + rowOff][j + colOff];
        }
    }
    return q;
}

// Returns a new h x h matrix holding a - b when subtract is true, otherwise a + b.
static int** strassenCombine(int** a, int** b, int h, bool subtract) {
    int** c = strassenAlloc(h);
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            c[i][j] = subtract ? a[i][j] - b[i][j] : a[i][j] + b[i][j];
        }
    }
    return c;
}

/*
 * Multiplies two 2^k x 2^k integer matrices using Strassen's algorithm
 * (seven recursive half-size products instead of eight).
 *
 * @param left   left operand; only its top-left 2^k x 2^k entries are read
 * @param right  right operand; only its top-left 2^k x 2^k entries are read
 * @param k      log2 of the matrix order; k <= 0 is treated as a 1 x 1 product
 * @return newly allocated 2^k x 2^k product (row-pointer array plus one malloc per row).
 *         The caller owns it and must free every row and then the row array.
 */
int** Strassen(int** left, int** right, int k) {
    if (k <= 0) {
        // 1 x 1 base case. Without it, k == 0 recursed forever with k decreasing.
        int** scalar = strassenAlloc(1);
        scalar[0][0] = left[0][0] * right[0][0];
        return scalar;
    }

    const int n = 1 << k;  // exact, unlike pow(2, k), which goes through floating point
    const int h = n / 2;
    int** result = strassenAlloc(n);

    if (k == 1) {
        // 2 x 2: the direct formula is cheaper than one more level of recursion.
        result[0][0] = left[0][0] * right[0][0] + left[0][1] * right[1][0];
        result[0][1] = left[0][0] * right[0][1] + left[0][1] * right[1][1];
        result[1][0] = left[1][0] * right[0][0] + left[1][1] * right[1][0];
        result[1][1] = left[1][0] * right[0][1] + left[1][1] * right[1][1];
        return result;
    }

    // Split both operands into h x h quadrants.
    int** left11 = strassenQuadrant(left, h, 0, 0);
    int** left12 = strassenQuadrant(left, h, 0, h);
    int** left21 = strassenQuadrant(left, h, h, 0);
    int** left22 = strassenQuadrant(left, h, h, h);
    int** right11 = strassenQuadrant(right, h, 0, 0);
    int** right12 = strassenQuadrant(right, h, 0, h);
    int** right21 = strassenQuadrant(right, h, h, 0);
    int** right22 = strassenQuadrant(right, h, h, h);

    // The ten sums/differences S1..S10 that feed the seven products.
    int** s1 = strassenCombine(right12, right22, h, true);
    int** s2 = strassenCombine(left11, left12, h, false);
    int** s3 = strassenCombine(left21, left22, h, false);
    int** s4 = strassenCombine(right21, right11, h, true);
    int** s5 = strassenCombine(left11, left22, h, false);
    int** s6 = strassenCombine(right11, right22, h, false);
    int** s7 = strassenCombine(left12, left22, h, true);
    int** s8 = strassenCombine(right21, right22, h, false);
    int** s9 = strassenCombine(left11, left21, h, true);
    int** s10 = strassenCombine(right11, right12, h, false);

    // Seven recursive half-size products.
    int** m1 = Strassen(left11, s1, k - 1);
    int** m2 = Strassen(s2, right22, k - 1);
    int** m3 = Strassen(s3, right11, k - 1);
    int** m4 = Strassen(left22, s4, k - 1);
    int** m5 = Strassen(s5, s6, k - 1);
    int** m6 = Strassen(s7, s8, k - 1);
    int** m7 = Strassen(s9, s10, k - 1);

    // Assemble the four result quadrants directly into result.
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            result[i][j] = m5[i][j] + m4[i][j] - m2[i][j] + m6[i][j];
            result[i][j + h] = m1[i][j] + m2[i][j];
            result[i + h][j] = m3[i][j] + m4[i][j];
            result[i + h][j + h] = m5[i][j] + m1[i][j] - m3[i][j] - m7[i][j];
        }
    }

    // Every intermediate matrix above is owned by this call; release all of them.
    int** temps[] = {
        left11, left12, left21, left22, right11, right12, right21, right22,
        s1, s2, s3, s4, s5, s6, s7, s8, s9, s10,
        m1, m2, m3, m4, m5, m6, m7,
    };
    for (size_t t = 0; t < sizeof(temps) / sizeof(temps[0]); t++) {
        strassenFree(temps[t], h);
    }
    return result;
}
时间复杂度
分析
用传统方法计算两个 $m\times m$ 分块矩阵(每块为 $2^k\times 2^k$ 矩阵)的乘积,需要计算 $O(m^3)$ 次 $2^k\times 2^k$ 矩阵的乘积;用 Strassen 算法计算一次 $2^k\times 2^k$ 矩阵乘积所需的时间为 $O(7^k)$。
复杂度计算
若 $T(n)$ 为算法的时间复杂度,$n$ 为矩阵阶数,$m$ 为奇数,$k$ 为非负整数,使得:

$$n = m\cdot 2^k$$
则时间复杂度为:

$$T(n)=\begin{cases} O(n^3), & k=0\\ O\left(m^3\cdot 7^{\log_2 (n/m)}\right)=O\left(m^3\cdot 7^{k}\right), & k\ge 1 \end{cases}$$

特别地,当 $m=1$(即 $n=2^k$)时,$T(n)=O(7^{\log_2 n})=O(n^{\log_2 7})\approx O(n^{2.81})$。