算法分析与设计:分治递归实现斯特拉森(Strassen)矩阵乘法

分治问题

分治法的基本思想是将一个规模很大的问题分解为许多规模较小的子问题,这些子问题互相独立且与原问题相同。对这些子问题递归地进行求解后合并,得到原问题的解。


斯特拉森(Strassen)矩阵乘法

设A和B是两个 n × n n \times n n×n的矩阵,其中 n n n可以写成 2 k 2^k 2k。将A和B分别等分成4个小矩阵,此时如果把A和B都当成 2 × 2 2 \times 2 2×2矩阵来看,每个元素就是一个 ( n / 2 ) × ( n / 2 ) (\mathrm{n} / 2) \times(\mathrm{n} / 2) (n/2)×(n/2) 矩阵,而矩阵A和B的乘积就可以写成
( A 11 A 12 A 21 A 22 ) ( B 11 B 12 B 21 B 22 ) = ( C 11 C 12 C 21 C 22 ) \left(\begin{array}{ll} A_{11} & A_{12} \\ A_{21} & A_{22} \end{array}\right)\left(\begin{array}{ll} B_{11} & B_{12} \\ B_{21} & B_{22} \end{array}\right)=\left(\begin{array}{ll} C_{11} & C_{12} \\ C_{21} & C_{22} \end{array}\right) (A11A21A12A22)(B11B21B12B22)=(C11C21C12C22)
其中利用斯特拉森方法得到7个小矩阵,分别定义为:
M 1 = A 11 ( B 12 − B 22 ) M 2 = ( A 11 + A 12 ) B 22 M 3 = ( A 21 + A 22 ) B 11 M 4 = A 22 ( B 21 − B 11 ) M 5 = ( A 11 + A 22 ) ( B 11 + B 22 ) M 6 = ( A 12 − A 22 ) ( B 21 + B 22 ) M 7 = ( A 11 − A 21 ) ( B 11 + B 12 ) \begin{array}{rl} M_{1}=& A_{11}\left(B_{12}-B_{22}\right) \\ M_{2}=&\left(A_{11}+A_{12}\right) B_{22} \\ M_{3}=&\left(A_{21}+A_{22}\right) B_{11} \\ M_{4}=& A_{22}\left(B_{21}-B_{11}\right) \\ M_{5}=&\left(A_{11}+A_{22}\right)\left(B_{11}+B_{22}\right) \\ M_{6}=&\left(A_{12}-A_{22}\right)\left(B_{21}+B_{22}\right) \\ M_{7}=&\left(A_{11}-A_{21}\right)\left(B_{11}+B_{12}\right) \\ \end{array} M1=M2=M3=M4=M5=M6=M7=A11(B12B22)(A11+A12)B22(A21+A22)B11A22(B21B11)(A11+A22)(B11+B22)(A12A22)(B21+B22)(A11A21)(B11+B12)

矩阵 M 1 ∼ M 7 M_{1} \sim M_{7} M1M7 可以通过7次矩阵乘法、6次矩阵加法和4次矩阵减法计算得出,前述4个小矩阵 C 1 ∼ C 4 C_{1} \sim C_{4} C1C4 可以由矩阵 M 1 ∼ M 7 M_{1} \sim M_{7} M1M7 通过6次矩阵加法和2次矩阵减法得出,方法如下:
C 11 = M 4 + M 5 − M 2 + M 6 C 12 = M 1 + M 2 C 21 = M 3 + M 4 C 22 = M 5 + M 1 − M 3 − M 7 \begin{aligned} C_{11}=&M_{4}+M_{5}-M_{2}+M_{6}\\ C_{12}=&M_{1}+M_{2}\\ C_{21}=&M_{3}+M_{4}\\ C_{22}=&M_{5}+M_{1}-M_{3}-M_{7} \end{aligned} C11=C12=C21=C22=M4+M5M2+M6M1+M2M3+M4M5+M1M3M7


分析算法时间复杂度

T ( k ) T(k) T(k)是斯特拉森(Strassen)算法对两个 2 k 2^k 2k阶方阵进行运算所需的时间,Strassen算法中共使用了7次递归调用的矩阵乘法(即在生成 M 1 ∼ M 7 M_1\sim M_7 M1M7的7个矩阵的时候),共使用了18次矩阵加减运算(在生成 M 1 ∼ M 7 M_1\sim M_7 M1M7的时候使用了6次矩阵加法和4次矩阵减法,在生成 C 1 ∼ C 4 C_1\sim C_4 C1C4的时候使用了6次矩阵加法和2次矩阵减法)。所以其递归表达式如下:

{ T ( k ) = O ( 1 ) , ( k ≤ 1 ) T ( k ) = 7 T ( k − 1 ) + O ( 2 2 k ) , ( k ≥ 2 ) \left\{\begin{array}{ll}T(k)=O(1), & (k \leq 1) \\ T(k)=7 T(k-1)+O\left(2^{2 k}\right), & (k \geq 2)\end{array}\right. {T(k)=O(1),T(k)=7T(k1)+O(22k),(k1)(k2)

n = 2 k n=2^k n=2k进行求解: T ( n ) = O ( n l o g 7 ) T(n)=O(n^{log7} ) T(n)=O(nlog7),即斯特拉森算法的时间复杂度为 O ( n l o g 7 ) O(n^{log7} ) O(nlog7)。而传统的矩阵乘法中有三层 f o r for for循环,因此显然其时间复杂度为 O ( n 3 ) O(n^3) O(n3)。由此可见,Strassen矩阵乘法的计算时间复杂性比普通算法有着较大改进。


C++分治递归实现斯特拉森(Strassen)矩阵乘法

#include <iostream>
#include <string.h>

using namespace std;

/*****************************************************************
* 函数描述: 析取矩阵元素
* 函数参数: pM——矩阵指针
			nCol——使用矩阵列大小(指针取值使用)
			i——索引横坐标
			j——索引纵坐标
* 函数返回: 矩阵对应位置的元素值
*****************************************************************/
int& GetArrayVal(int* pM, int nCol, int i, int j)
{
	return *(pM + i * nCol + j);
}


/*****************************************************************
* 函数描述: 创建矩阵
			矩阵元素为 [0,5] 范围内的随机整数
* 函数参数: pM——矩阵指针
			nRow——创建矩阵的行规模
			nCol——创建矩阵的列规模
* 函数返回: void
*****************************************************************/
void CreateMatrix(int** pM, int nRow, int nCol)
{
	*pM = new int[nRow * nCol];
	memset(*pM, 0, sizeof(int*) * nRow * nCol);

	for (int i = 0; i < nRow; ++i)
	{
		for (int j = 0; j < nCol; ++j)
		{
			GetArrayVal(*pM, nCol, i, j) = rand() % 6;
		}
	}
}


/*****************************************************************
* 函数描述: 销毁矩阵(内存管理)
* 函数参数: pM——矩阵指针
* 函数返回:void
*****************************************************************/
void DeleteMatrix(int** pM)
{
	if (NULL != *pM)
	{
		delete* pM;
		*pM = NULL;
	}
}


/*****************************************************************
* 函数描述: 矩阵加减法(n阶方阵)
* 函数参数: pM1——矩阵1
			nLeftIndex1,nTopIndex1——矩阵1左上角索引点(相对于源矩阵pMl)
			nTotalCol1——矩阵1实际使用的列数
			pM2——矩阵2
			nLeftIndex2, nTopIndex2——矩阵2左上角索引点(相对于源矩阵pM2)
			nTotalCol2——矩阵2实际使用的列数
			nCount——方阵阶数n
			pResult——运算结果矩阵
			bAdd——加减标记
* 函数返回:void
*****************************************************************/
void MatrixAddOrSub(int* pM1, int nLeftIndex1, int nTopIndex1, int nTotalCol1,
	int* pM2, int nLeftIndex2, int nTopIndex2, int nTotalCol2,
	int nCount, int** pResult, bool bAdd)
{
	*pResult = new int[nCount * nCount];
	for (int i = 0; i < nCount; ++i)
	{
		for (int j = 0; j < nCount; ++j)
		{
			if (bAdd)	// 加法
			{
				GetArrayVal(*pResult, nCount, i, j) = GetArrayVal(pM1, nTotalCol1, nLeftIndex1 + i, nTopIndex1 + j)
					+ GetArrayVal(pM2, nTotalCol2, nLeftIndex2 + i, nTopIndex2 + j);
			}
			else		// 减法
			{
				GetArrayVal(*pResult, nCount, i, j) = GetArrayVal(pM1, nTotalCol1, nLeftIndex1 + i, nTopIndex1 + j)
					- GetArrayVal(pM2, nTotalCol2, nLeftIndex2 + i, nTopIndex2 + j);
			}

		}
	}
}


/*****************************************************************
* 函数描述: 矩阵乘法(n阶方阵)
* 函数参数: pM1——矩阵1
			nLeftIndex1,nTopIndex1——矩阵1左上角索引点(相对于源矩阵pMl)
			nTotalCol1——矩阵1实际使用的列数
			pM2——矩阵2
			nLeftIndex2, nTopIndex2——矩阵2左上角索引点(相对于源矩阵pM2)
			nTotalCol2——矩阵2实际使用的列数
			nCount——方阵阶数n
			pResult——运算结果矩阵
* 函数返回:void
*****************************************************************/
void MatrixMulti(int* pM1, int nLeftIndex1, int nTopIndex1, int nTotalCol1,
	int* pM2, int nLeftIndex2, int nTopIndex2, int nTotalCol2,
	int nCount, int** pResult)
{

	*pResult = new int[nCount * nCount];
	for (int i = 0; i < nCount; ++i)
	{
		for (int j = 0; j < nCount; ++j)
		{
			GetArrayVal(*pResult, nCount, i, j) = 0;
			for (int k = 0; k < nCount; ++k)
			{
				GetArrayVal(*pResult, nCount, i, j) += GetArrayVal(pM1, nTotalCol1, nLeftIndex1 + i, nTopIndex1 + k)
					* GetArrayVal(pM2, nTotalCol2, nLeftIndex2 + k, nTopIndex2 + j);
			}
		}
	}
}


/*****************************************************************
* 函数描述: 递归实现斯特拉森矩阵乘法(n阶方阵)
* 函数参数: pM1——矩阵1
			nLeftIndex1,nTopIndex1——矩阵1左上角索引点(相对于源矩阵pMl)
			nTotalCol1——矩阵1实际使用的列数
			pM2——矩阵2
			nLeftIndex2, nTopIndex2——矩阵2左上角索引点(相对于源矩阵pM2)
			nTotalCol2——矩阵2实际使用的列数
			nCount——方阵阶数n
			pResult——运算结果矩阵
* 函数返回:void
*****************************************************************/
void StrassenMatrix(int* pM1, int nLeftIndex1, int nTopIndex1, int nTotalCol1,
	int* pM2, int nLeftIndex2, int nTopIndex2, int nTotalCol2,
	int nCount, int** pResult)
{
	if (nCount == 2) // 如果当前为2阶,不能继续划分则跳出迭代
	{
		MatrixMulti(pM1, nLeftIndex1, nTopIndex1, nTotalCol1,
			pM2, nLeftIndex2, nTopIndex2, nTotalCol2, nCount, pResult);
	}
	else	// 如果当前大于2阶,拆分成4个大小相等的子矩阵,分别进行迭代
	{
		int* pResultM1 = NULL;
		int* pResultM2 = NULL;
		int* pResultM3 = NULL;
		int* pResultM4 = NULL;
		int* pResultM5 = NULL;
		int* pResultM6 = NULL;
		int* pResultM7 = NULL;

		// M1 = A11 * (B12 - B22)
		int* pB12_B22 = NULL;
		MatrixAddOrSub(pM2, nLeftIndex2, nTopIndex2 + nCount / 2, nTotalCol2,
			pM2, nLeftIndex2 + nCount / 2, nTopIndex2 + nCount / 2, nTotalCol2, nCount / 2, &pB12_B22, false);
		StrassenMatrix(pM1, nLeftIndex1, nTopIndex1, nTotalCol1,
			pB12_B22, 0, 0, nCount / 2, nCount / 2, &pResultM1);

		// M2 = (A11 + A12) * B22;
		int* pA11_A12 = NULL;
		MatrixAddOrSub(pM1, nLeftIndex1, nTopIndex1, nTotalCol1,
			pM1, nLeftIndex1, nTopIndex1 + nCount / 2, nTotalCol1, nCount / 2, &pA11_A12, true);
		StrassenMatrix(pA11_A12, 0, 0, nCount / 2,
			pM2, nLeftIndex2 + nCount / 2, nTopIndex2 + nCount / 2, nTotalCol2, nCount / 2, &pResultM2);

		// M3 = (A21 + A22) * B11;
		int* pA21_A22 = NULL;
		MatrixAddOrSub(pM1, nLeftIndex1 + nCount / 2, nTopIndex1, nTotalCol1,
			pM1, nLeftIndex1 + nCount / 2, nTopIndex1 + nCount / 2, nTotalCol1, nCount / 2, &pA21_A22, true);
		StrassenMatrix(pA21_A22, 0, 0, nCount / 2,
			pM2, nLeftIndex2, nTopIndex2, nTotalCol2, nCount / 2, &pResultM3);

		// M4 = A22 * (B21 - B11)
		int* pB21_B11 = NULL;
		MatrixAddOrSub(pM2, nLeftIndex2 + nCount / 2, nTopIndex2, nTotalCol2,
			pM2, nLeftIndex2, nTopIndex2, nTotalCol2, nCount / 2, &pB21_B11, false);
		StrassenMatrix(pM1, nLeftIndex1 + nCount / 2, nTopIndex1 + nCount / 2, nTotalCol1,
			pB21_B11, 0, 0, nCount / 2, nCount / 2, &pResultM4);


		// M5 = (A11 + A22) * (B11 + B22)
		int* pA11_A22 = NULL;
		int* pB11_B22 = NULL;
		MatrixAddOrSub(pM1, nLeftIndex1, nTopIndex1, nTotalCol1,
			pM1, nLeftIndex1 + nCount / 2, nTopIndex1 + nCount / 2, nTotalCol1, nCount / 2, &pA11_A22, true);
		MatrixAddOrSub(pM2, nLeftIndex2, nTopIndex2, nTotalCol2,
			pM2, nLeftIndex2 + nCount / 2, nTopIndex2 + nCount / 2, nTotalCol2, nCount / 2, &pB11_B22, true);
		StrassenMatrix(pA11_A22, 0, 0, nCount / 2,
			pB11_B22, 0, 0, nCount / 2, nCount / 2, &pResultM5);

		// M6 = (A12 - A22) * (B21 + B22)
		int* pA12_A22 = NULL;
		int* pB21_B22 = NULL;
		MatrixAddOrSub(pM1, nLeftIndex1, nTopIndex1 + nCount / 2, nTotalCol1,
			pM1, nLeftIndex1 + nCount / 2, nTopIndex1 + nCount / 2, nTotalCol1, nCount / 2, &pA12_A22, false);
		MatrixAddOrSub(pM2, nLeftIndex2 + nCount / 2, nTopIndex2, nTotalCol2,
			pM2, nLeftIndex2 + nCount / 2, nTopIndex2 + nCount / 2, nTotalCol2, nCount / 2, &pB21_B22, true);
		StrassenMatrix(pA12_A22, 0, 0, nCount / 2,
			pB21_B22, 0, 0, nCount / 2, nCount / 2, &pResultM6);

		// M7 = (A11 - A21) * (B11 + B12)
		int* pA11_A21 = NULL;
		int* pB11_B12 = NULL;
		MatrixAddOrSub(pM1, nLeftIndex1, nTopIndex1, nTotalCol1,
			pM1, nLeftIndex1 + nCount / 2, nTopIndex1, nTotalCol1, nCount / 2, &pA11_A21, false);
		MatrixAddOrSub(pM2, nLeftIndex2, nTopIndex2, nTotalCol2,
			pM2, nLeftIndex2, nTopIndex2 + nCount / 2, nTotalCol2, nCount / 2, &pB11_B12, true);
		StrassenMatrix(pA11_A21, 0, 0, nCount / 2,
			pB11_B12, 0, 0, nCount / 2, nCount / 2, &pResultM7);


		int* pResultC11 = NULL;
		int* pResultC12 = NULL;
		int* pResultC21 = NULL;
		int* pResultC22 = NULL;
		int* pResultTemp1 = NULL;
		int* pResultTemp2 = NULL;

		// C11 = M5 + M4 - M2 + M6
		MatrixAddOrSub(pResultM5, 0, 0, nCount / 2,
			pResultM4, 0, 0, nCount / 2, nCount / 2, &pResultTemp1, true);
		MatrixAddOrSub(pResultTemp1, 0, 0, nCount / 2,
			pResultM2, 0, 0, nCount / 2, nCount / 2, &pResultTemp2, false);
		MatrixAddOrSub(pResultTemp2, 0, 0, nCount / 2,
			pResultM6, 0, 0, nCount / 2, nCount / 2, &pResultC11, true);


		// C12 = M1 + M2
		MatrixAddOrSub(pResultM1, 0, 0, nCount / 2,
			pResultM2, 0, 0, nCount / 2, nCount / 2, &pResultC12, true);

		// C21 = M3 + M4 
		MatrixAddOrSub(pResultM3, 0, 0, nCount / 2,
			pResultM4, 0, 0, nCount / 2, nCount / 2, &pResultC21, true);

		// C22 = M5 + M1 - M3 - M7
		MatrixAddOrSub(pResultM5, 0, 0, nCount / 2,
			pResultM1, 0, 0, nCount / 2, nCount / 2, &pResultTemp1, true);
		MatrixAddOrSub(pResultTemp1, 0, 0, nCount / 2,
			pResultM3, 0, 0, nCount / 2, nCount / 2, &pResultTemp2, false);
		MatrixAddOrSub(pResultTemp2, 0, 0, nCount / 2,
			pResultM7, 0, 0, nCount / 2, nCount / 2, &pResultC22, false);


		// 构造结果
		*pResult = new int[nCount * nCount];
		for (int i = 0; i < nCount / 2; ++i)
		{
			for (int j = 0; j < nCount / 2; ++j)
			{
				GetArrayVal(*pResult, nCount, i, j) = GetArrayVal(pResultC11, nCount / 2, i, j);
				GetArrayVal(*pResult, nCount, i, j + nCount / 2) = GetArrayVal(pResultC12, nCount / 2, i, j);
				GetArrayVal(*pResult, nCount, i + nCount / 2, j) = GetArrayVal(pResultC21, nCount / 2, i, j);
				GetArrayVal(*pResult, nCount, i + nCount / 2, j + nCount / 2) = GetArrayVal(pResultC22, nCount / 2, i, j);

			}
		}

		//释放内存
		DeleteMatrix(&pResultM1);
		DeleteMatrix(&pResultM2);
		DeleteMatrix(&pResultM3);
		DeleteMatrix(&pResultM4);
		DeleteMatrix(&pResultM5);
		DeleteMatrix(&pResultM6);
		DeleteMatrix(&pResultM7);
		DeleteMatrix(&pA11_A12);
		DeleteMatrix(&pA21_A22);
		DeleteMatrix(&pB12_B22);
		DeleteMatrix(&pB21_B11);
		DeleteMatrix(&pA11_A22);
		DeleteMatrix(&pB11_B22);
		DeleteMatrix(&pA12_A22);
		DeleteMatrix(&pB21_B22);
		DeleteMatrix(&pA11_A21);
		DeleteMatrix(&pB11_B12);
		DeleteMatrix(&pResultTemp1);
		DeleteMatrix(&pResultTemp2);
		DeleteMatrix(&pResultC11);
		DeleteMatrix(&pResultC12);
		DeleteMatrix(&pResultC21);
		DeleteMatrix(&pResultC22);

	}
}

/*****************************************************************
* 函数描述: 打印矩阵
* 函数参数: pM——矩阵指针
			nRow——矩阵行规模
			nCol——矩阵列规模
* 函数返回:void
*****************************************************************/
void PrintMatrix(int* pM, int nRow, int nCol)
{
	for (int i = 0; i < nRow; ++i)
	{
		for (int j = 0; j < nCol; ++j)
		{
			cout << GetArrayVal(pM, nCol, i, j) << " ";
		}
		cout << endl;
	}
}

int main()
{
	srand(0);	// 随机数种子

	int* pM1 = NULL;
	int* pM2 = NULL;
	int* pMResult1 = NULL;
	int nRow1, nCol1, nRow2, nCol2;

	// 测试4阶方阵乘法
	nRow1 = 4;
	nCol1 = 4;
	nRow2 = 4;
	nCol2 = 4;

	// 测试8阶方阵乘法
	//nRow1 = 8;
	//nCol1 = 8;
	//nRow2 = 8;
	//nCol2 = 8;

	// 随机构建两个矩阵,并打印
	CreateMatrix(&pM1, nRow1, nCol1);
	CreateMatrix(&pM2, nRow2, nCol2);
	cout << "\nMatrix A:" << endl;
	PrintMatrix(pM1, nRow1, nCol1);
	cout << "\nMatrix B:" << endl;
	PrintMatrix(pM2, nRow2, nCol2);

	// 使用普通方法计算矩阵乘法
	cout << "\nGeneral matrix multiplication:" << endl;
	MatrixMulti(pM1, 0, 0, nRow1, pM2, 0, 0, nRow1, nRow1, &pMResult1);
	PrintMatrix(pMResult1, nRow1, nCol2);

	// 使用斯特拉森方法计算矩阵乘法
	cout << "\nStrassen matrix multiplication:" << endl;
	StrassenMatrix(pM1, 0, 0, nRow1, pM2, 0, 0, nRow1, nRow1, &pMResult1);
	PrintMatrix(pMResult1, nRow1, nCol2);

	//内存释放
	DeleteMatrix(&pM1);
	DeleteMatrix(&pM2);
	DeleteMatrix(&pMResult1);
	cout << "-----------------------------------" << endl;

	return 0;
}

代码输出

阶数为4的矩阵乘法

阶数为8的矩阵乘法

参考:https://blog.csdn.net/s634772208/article/details/46594707

  • 13
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值