Strassen矩阵乘法问题及延申

问题描述

在这里插入图片描述

参考答案

在这里插入图片描述

算法分析

由于m是奇数,不适用Strassen算法。2k 是2的倍数,适用于Strassen算法。所以计算nn阶矩阵可以先用传统算法计算mm个子矩阵的乘积,在用Strassen算法计算2k*2k矩阵之间的乘积,并在计算m1-m7可以利用Strassen算法计算子矩阵的子矩阵。

递归函数

分析

在这里插入图片描述如上图所示,在我们调用Strassen算法进行计算2k*2k阶矩阵相乘时,它的子矩阵同样适用于Strassen算法。

算法实现

Strassen乘法递归实现

#include <iostream>
#include <stdlib.h>
#include <math.h>
using namespace std;
//函数声明
int getK(int k);
int** Strassen(int** left, int** right, int k);
int** add(int** a1, int** a2, int n);
int** sub(int** a1, int** a2, int n);

/*
* @param argc 参数格式
* @param argv[0] 矩阵阶数
* @param argv[1]argv[2] 输入矩阵
*/
int main(int argc, char** argv)
{
	if (argc != 3)
		return -1;
	//n为矩阵阶数
	int n = (int)argv[0];
	//为输入矩阵申请空间
	int** left = (int**)malloc(n * sizeof(int*));
	int** right = (int**)malloc(n * sizeof(int*));
	for (int i = 0; i < n; i++) {
		left[i] = (int*)malloc(n * sizeof(int));
		right[i] = (int*)malloc(n * sizeof(int));
	}
	//传递参数
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			left[i][j] = (int)argv[1][i*j + j];
			right[i][j] = (int)argv[2][i*j + j];
		}
	}
	//调用函数
    Strassen(left, right, getK(n));
}


//根据n = m*2k,求k
int getK(int n) {
	int k = 0;
	int target = 0;
	while (true)
	{
		target = n % 2;
		if (target == 1) {
			return k;
		}
		n = n / 2;
		k++;
	}
}

//矩阵加法
int** add(int** a1, int** a2, int n) {
	int** c = (int**)malloc(n * sizeof(int*));
	for (int i = 0; i < n; i++) {
		c[i] = (int*)malloc(n * sizeof(int));
	}

	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			c[i][j] = a1[i][j] + a2[i][j];
		}
	}
	return c;
}

//矩阵减法
int** sub(int** a1, int** a2, int n) {
	int** c = (int**)malloc(n * sizeof(int*));
	for (int i = 0; i < n; i++) {
		c[i] = (int*)malloc(n * sizeof(int));
	}

	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			c[i][j] = a1[i][j] - a2[i][j];
		}
	}
	return c;
}


//Strassen算法,递归求2k*2k矩阵
/*
* @param k:矩阵阶数为2k*2k
* @param left,right分别为乘号左边右边矩阵
*/
int** Strassen(int** left, int** right, int k) {
	//n为函数阶数
	int n = pow(2, k);
	//为结果矩阵申请内存空间
	int** result = (int**)malloc(n * sizeof(int*));
	for (int i = 0; i < n; i++) {
		result[i] = (int*)malloc(n * sizeof(int));
	}
	if (k == 1) {
		result[0][0] = left[0][0] * right[0][0] + left[0][1] * right[1][0];
		result[0][1] = left[0][0] * right[0][1] + left[0][1] * right[1][1];
		result[1][0] = left[1][0] * right[0][0] + left[1][1] * right[1][0];
		result[1][1] = left[1][0] * right[0][1] + left[1][1] * right[1][1];
		return result;
	}
	//将输入矩阵分为四份并申请空间,为计算结果m1-m7申请空间,为结果子矩阵申请空间
	int** left11 = (int**)malloc(n / 2 * sizeof(int*));
	int** left12 = (int**)malloc(n / 2 * sizeof(int*));
	int** left21 = (int**)malloc(n / 2 * sizeof(int*));
	int** left22 = (int**)malloc(n / 2 * sizeof(int*));
	int** right11 = (int**)malloc(n / 2 * sizeof(int*));
	int** right12 = (int**)malloc(n / 2 * sizeof(int*));
	int** right21 = (int**)malloc(n / 2 * sizeof(int*));
	int** right22 = (int**)malloc(n / 2 * sizeof(int*));

	int** m1 = (int**)malloc(n / 2 * sizeof(int*));
	int** m2 = (int**)malloc(n / 2 * sizeof(int*));
	int** m3 = (int**)malloc(n / 2 * sizeof(int*));
	int** m4 = (int**)malloc(n / 2 * sizeof(int*));
	int** m5 = (int**)malloc(n / 2 * sizeof(int*));
	int** m6 = (int**)malloc(n / 2 * sizeof(int*));
	int** m7 = (int**)malloc(n / 2 * sizeof(int*));

	int** result11 = (int**)malloc(n / 2 * sizeof(int*));
	int** result12 = (int**)malloc(n / 2 * sizeof(int*));
	int** result21 = (int**)malloc(n / 2 * sizeof(int*));
	int** result22 = (int**)malloc(n / 2 * sizeof(int*));

	for (int i = 0; i < n / 2; i++) {
		left11[i] = (int*)malloc(n / 2 * sizeof(int));
		left12[i] = (int*)malloc(n / 2 * sizeof(int));
		left21[i] = (int*)malloc(n / 2 * sizeof(int));
		left22[i] = (int*)malloc(n / 2 * sizeof(int));
		right11[i] = (int*)malloc(n / 2 * sizeof(int));
		right12[i] = (int*)malloc(n / 2 * sizeof(int));
		right21[i] = (int*)malloc(n / 2 * sizeof(int));
		right22[i] = (int*)malloc(n / 2 * sizeof(int));

		m1[i] = (int*)malloc(n / 2 * sizeof(int));
		m2[i] = (int*)malloc(n / 2 * sizeof(int));
		m3[i] = (int*)malloc(n / 2 * sizeof(int));
		m4[i] = (int*)malloc(n / 2 * sizeof(int));
		m5[i] = (int*)malloc(n / 2 * sizeof(int));
		m6[i] = (int*)malloc(n / 2 * sizeof(int));
		m7[i] = (int*)malloc(n / 2 * sizeof(int));

	}
	//复制内容到子矩阵
	for (int i = 0; i < n / 2; i++) {
		for (int j = 0; j < n / 2; j++) {
			left11[i][j] = left[i][j];
			left12[i][j] = left[i][j + n / 2];
			left21[i][j] = left[i + n / 2][j];
			left22[i][j] = left[i + n / 2][j + n / 2];
			right11[i][j] = right[i][j];
			right12[i][j] = right[i][j + n / 2];
			right21[i][j] = right[i + n / 2][j];
			right22[i][j] = right[i + n / 2][j + n / 2];
		}
	}

	//递归计算m1-m7
	m1 = Strassen(left11, sub(right12, right22, n / 2), k - 1);
	m2 = Strassen(add(left11, left12, n / 2), right22, k - 1);
	m3 = Strassen(add(left21, left22, n / 2), right11, k - 1);
	m4 = Strassen(left22, sub(right21, right11, n / 2), k - 1);
	m5 = Strassen(add(left11, left22, n / 2), add(right11, right22, n / 2), k - 1);
	m6 = Strassen(sub(left12, left22, n / 2), add(right21, right22, n / 2), k - 1);
	m7 = Strassen(sub(left11, left21, n / 2), add(right11, right12, n / 2), k - 1);

	//计算结果子矩阵
	result11 = add(m5, add(m6, sub(m4, m2, n / 2), n / 2), n / 2);
	result12 = add(m1, m2, n / 2);
	result21 = add(m3, m4, n / 2);
	result22 = add(sub(sub(m1, m7, n / 2), m3, n / 2), m5, n / 2);
	//将子矩阵结果放入结果矩阵
	for (int i = 0; i < n / 2; i++) {
		for (int j = 0; j < n / 2; j++) {
			result[i][j] = result11[i][j];
			result[i][j + n / 2] = result12[i][j];
			result[i + n / 2][j] = result21[i][j];
			result[i + n / 2][j + n / 2] = result22[i][j];
		}
	}
	//返回结果指针
	return result;
}

时间复杂度

分析

用传统方法求两个m阶矩阵的乘积需要计算O(m^3)次2k2k矩阵的乘积,用Strassen算法计算2k2k矩阵乘积需要计算时间为

复杂度计算

若T(n)为算法复杂度,n为矩阵阶数,m是符合条件奇数,k为正整数。使得: n = m 2 k n=m2^k n=m2k
则时间复杂度为:
T ( n ) = { O ( n 3 ) , k = 0 O ( 7 l o g 2 n m 3 ) , k > = 1 T(n)=\begin{cases} O(n^3),k=0\\ O(7log_2^nm^3),k>=1\\ \end{cases} T(n)={O(n3),k=0O(7log2nm3),k>=1

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值