Strassen's Algorithm

From Introduction to Algorithms, 3rd Edition, Chapter 4: Divide-and-Conquer
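The core idea of Strassen's algorithm is to make the recursion tree slightly less bushy: instead of recursively performing 8 multiplications of $\frac{n}{2} \times \frac{n}{2}$ matrices, it performs only 7. The cost of eliminating one matrix multiplication is a few extra additions of $\frac{n}{2} \times \frac{n}{2}$ matrices, but only a constant number of them.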

Assume that $n$ is an exact power of 2, and partition each of A, B, and C into four $\frac{n}{2} \times \frac{n}{2}$ submatrices:
$$
A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}, \quad
B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}, \quad
C = \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
$$
The equation $C = A \cdot B$ can then be rewritten as
$$
\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix} = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \cdot \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
$$
which gives:
$$
\begin{aligned}
C_{11} &= A_{11} \cdot B_{11} + A_{12} \cdot B_{21} \\
C_{12} &= A_{11} \cdot B_{12} + A_{12} \cdot B_{22} \\
C_{21} &= A_{21} \cdot B_{11} + A_{22} \cdot B_{21} \\
C_{22} &= A_{21} \cdot B_{12} + A_{22} \cdot B_{22}
\end{aligned}
$$
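Computed directly, the four block formulas above need 8 multiplications of $\frac{n}{2} \times \frac{n}{2}$ matrices. The following is a minimal sketch (not from the original post) of that straightforward divide-and-conquer method, which is what Strassen's algorithm improves on. The names `Matrix`, `mulAddRec`, and `multiplyRecursive` are hypothetical, `long long` entries are an assumption, and $n$ is assumed to be a power of 2. Submatrices are identified by their top-left row/column offsets instead of being copied, which is how partitioning can take $\Theta(1)$ time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// C[cr.., cc..] += A[ar.., ac..] * B[br.., bc..] for n x n blocks.
// Each block is identified only by its top-left corner, so "partitioning"
// is pure index arithmetic and nothing is copied.
void mulAddRec(const Matrix& A, std::size_t ar, std::size_t ac,
               const Matrix& B, std::size_t br, std::size_t bc,
               Matrix& C, std::size_t cr, std::size_t cc, std::size_t n) {
    if (n == 1) {
        C[cr][cc] += A[ar][ac] * B[br][bc];
        return;
    }
    std::size_t h = n / 2;
    // C11 += A11*B11 + A12*B21
    mulAddRec(A, ar, ac, B, br, bc, C, cr, cc, h);
    mulAddRec(A, ar, ac + h, B, br + h, bc, C, cr, cc, h);
    // C12 += A11*B12 + A12*B22
    mulAddRec(A, ar, ac, B, br, bc + h, C, cr, cc + h, h);
    mulAddRec(A, ar, ac + h, B, br + h, bc + h, C, cr, cc + h, h);
    // C21 += A21*B11 + A22*B21
    mulAddRec(A, ar + h, ac, B, br, bc, C, cr + h, cc, h);
    mulAddRec(A, ar + h, ac + h, B, br + h, bc, C, cr + h, cc, h);
    // C22 += A21*B12 + A22*B22
    mulAddRec(A, ar + h, ac, B, br, bc + h, C, cr + h, cc + h, h);
    mulAddRec(A, ar + h, ac + h, B, br + h, bc + h, C, cr + h, cc + h, h);
}

// Returns A * B using 8 recursive half-size products; n must be a power of 2.
Matrix multiplyRecursive(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    assert(n > 0 && (n & (n - 1)) == 0 && B.size() == n);
    Matrix C(n, std::vector<long long>(n, 0));
    mulAddRec(A, 0, 0, B, 0, 0, C, 0, 0, n);
    return C;
}
```

Because each level spawns eight half-size subproblems, the recursion ends in $8^{\lg n} = n^3$ scalar multiplications, so this version is still $\Theta(n^3)$.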
Strassen's algorithm consists of four steps:

  1. Partition A, B, and C into submatrices as described above. This takes $\Theta(1)$ time if the submatrices are specified by index calculations rather than copied.

  2. Create the following 10 matrices $S_1, S_2, \ldots, S_{10}$, each of size $\frac{n}{2} \times \frac{n}{2}$ (taking $\Theta(n^2)$ time):
    $$
    \begin{aligned}
    S_1 &= B_{12} - B_{22} \\
    S_2 &= A_{11} + A_{12} \\
    S_3 &= A_{21} + A_{22} \\
    S_4 &= B_{21} - B_{11} \\
    S_5 &= A_{11} + A_{22} \\
    S_6 &= B_{11} + B_{22} \\
    S_7 &= A_{12} - A_{22} \\
    S_8 &= B_{21} + B_{22} \\
    S_9 &= A_{11} - A_{21} \\
    S_{10} &= B_{11} + B_{12}
    \end{aligned}
    $$

  3. Recursively compute the seven matrix products $P_1, P_2, \ldots, P_7$, each of size $\frac{n}{2} \times \frac{n}{2}$:
    $$
    \begin{aligned}
    P_1 &= A_{11} \cdot S_1 &&= A_{11} \cdot B_{12} - A_{11} \cdot B_{22} \\
    P_2 &= S_2 \cdot B_{22} &&= A_{11} \cdot B_{22} + A_{12} \cdot B_{22} \\
    P_3 &= S_3 \cdot B_{11} &&= A_{21} \cdot B_{11} + A_{22} \cdot B_{11} \\
    P_4 &= A_{22} \cdot S_4 &&= A_{22} \cdot B_{21} - A_{22} \cdot B_{11} \\
    P_5 &= S_5 \cdot S_6 &&= A_{11} \cdot B_{11} + A_{11} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22} \\
    P_6 &= S_7 \cdot S_8 &&= A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22} \\
    P_7 &= S_9 \cdot S_{10} &&= A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12}
    \end{aligned}
    $$
    Only the middle column has to be computed, using 7 recursive multiplications; the right-hand column merely expands each product in terms of the original submatrices to show what it equals.

  4. Compute $C_{11}, C_{12}, C_{21}, C_{22}$ by adding and subtracting the $P_i$ (taking $\Theta(n^2)$ time):
    $$
    \begin{aligned}
    C_{11} &= P_5 + P_4 - P_2 + P_6 \\
    C_{12} &= P_1 + P_2 \\
    C_{21} &= P_3 + P_4 \\
    C_{22} &= P_5 + P_1 - P_3 - P_7
    \end{aligned}
    $$
    For example, $C_{12} = P_1 + P_2 = A_{11} \cdot B_{12} + A_{12} \cdot B_{22}$, because the $A_{11} \cdot B_{22}$ terms cancel.
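Putting the four steps together, here is a compact sketch that follows the $S_i$/$P_i$ naming used above. It is an illustration under stated assumptions, not the post's own program: it uses `std::vector<std::vector<long long>>` instead of raw pointers, requires $n$ to be a power of 2, and copies the submatrices in step 1 for simplicity (that costs $\Theta(n^2)$ per call, which step 4 already spends, so the $\Theta(n^{\lg 7})$ bound is unchanged). The helpers `add`, `sub`, `block`, and `strassen` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Elementwise X + Y for equally sized square matrices.
Matrix add(const Matrix& X, const Matrix& Y) {
    Matrix R = X;
    for (std::size_t i = 0; i < R.size(); ++i)
        for (std::size_t j = 0; j < R.size(); ++j) R[i][j] += Y[i][j];
    return R;
}

// Elementwise X - Y for equally sized square matrices.
Matrix sub(const Matrix& X, const Matrix& Y) {
    Matrix R = X;
    for (std::size_t i = 0; i < R.size(); ++i)
        for (std::size_t j = 0; j < R.size(); ++j) R[i][j] -= Y[i][j];
    return R;
}

// Copies the h x h block of M whose top-left corner is (r, c).
Matrix block(const Matrix& M, std::size_t r, std::size_t c, std::size_t h) {
    Matrix R(h, std::vector<long long>(h));
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j) R[i][j] = M[r + i][c + j];
    return R;
}

// Strassen's method; A and B must be n x n with n a power of 2.
Matrix strassen(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    assert(n > 0 && (n & (n - 1)) == 0 && B.size() == n);
    if (n == 1) return {{A[0][0] * B[0][0]}};

    // Step 1: partition (copied here for simplicity).
    std::size_t h = n / 2;
    Matrix A11 = block(A, 0, 0, h), A12 = block(A, 0, h, h),
           A21 = block(A, h, 0, h), A22 = block(A, h, h, h);
    Matrix B11 = block(B, 0, 0, h), B12 = block(B, 0, h, h),
           B21 = block(B, h, 0, h), B22 = block(B, h, h, h);

    // Steps 2 and 3: the S_i are formed inline as arguments of the 7 products.
    Matrix P1 = strassen(A11, sub(B12, B22));             // A11 * S1
    Matrix P2 = strassen(add(A11, A12), B22);             // S2 * B22
    Matrix P3 = strassen(add(A21, A22), B11);             // S3 * B11
    Matrix P4 = strassen(A22, sub(B21, B11));             // A22 * S4
    Matrix P5 = strassen(add(A11, A22), add(B11, B22));   // S5 * S6
    Matrix P6 = strassen(sub(A12, A22), add(B21, B22));   // S7 * S8
    Matrix P7 = strassen(sub(A11, A21), add(B11, B12));   // S9 * S10

    // Step 4: combine the products into the quadrants of C.
    Matrix C11 = add(sub(add(P5, P4), P2), P6);
    Matrix C12 = add(P1, P2);
    Matrix C21 = add(P3, P4);
    Matrix C22 = sub(sub(add(P5, P1), P3), P7);

    Matrix C(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j) {
            C[i][j] = C11[i][j];
            C[i][j + h] = C12[i][j];
            C[i + h][j] = C21[i][j];
            C[i + h][j + h] = C22[i][j];
        }
    return C;
}
```

Recursing all the way down to $1 \times 1$ blocks keeps the sketch close to the recurrence; practical versions usually stop at a larger cutoff and fall back to the ordinary triple loop.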

This yields the recurrence
$$
T(n) = \begin{cases} \Theta(1) & \text{if } n = 1, \\ 7T(n/2) + \Theta(n^2) & \text{if } n > 1. \end{cases}
$$
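Solving it with a recursion tree or the master method gives
$$
T(n) = \Theta(n^{\lg 7})
$$
Since $\lg 7 \approx 2.81$, this is asymptotically faster than the $\Theta(n^3)$ of the straightforward method.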
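To get a feel for what replacing 8 by 7 buys, a minimal sketch (the helper names `scalarMults` and `strassenExponent` are hypothetical) can count the scalar multiplications performed at the leaves when the recursion runs all the way down to $1 \times 1$ blocks for $n = 2^k$. That count satisfies $M(n) = 8M(n/2)$ for the straightforward method and $M(n) = 7M(n/2)$ for Strassen, with $M(1) = 1$; additions are ignored.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Leaf multiplications for n = 2^k when each level spawns `branching`
// half-size subproblems: M(n) = branching * M(n/2), M(1) = 1, i.e. branching^k.
std::uint64_t scalarMults(std::uint64_t branching, unsigned k) {
    assert(k <= 21);  // keeps 8^k within 64 bits (8^21 = 2^63)
    std::uint64_t m = 1;
    for (unsigned i = 0; i < k; ++i) m *= branching;
    return m;
}

// The exponent in Theta(n^{lg 7}), since 7^k = (2^k)^{lg 7} = n^{lg 7}.
double strassenExponent() { return std::log2(7.0); }
```

For $n = 1024$ ($k = 10$) this gives $8^{10} = 1{,}073{,}741{,}824$ versus $7^{10} = 282{,}475{,}249$, roughly a 3.8× reduction in multiplications. The extra additions and memory traffic eat into that gain in practice, which is why implementations usually switch to the ordinary method below some cutoff size.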

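The C++ implementation follows. It requires the matrix size to be a power of 2, and it names the seven products $M_1, \ldots, M_7$ (another common convention) rather than $P_1, \ldots, P_7$. The correspondence is $M_1 = P_5$, $M_2 = P_3$, $M_3 = P_1$, $M_4 = P_4$, $M_5 = P_2$, $M_6 = -P_7$, $M_7 = P_6$, so the code combines them as $C_{11} = M_1 + M_4 - M_5 + M_7$, $C_{12} = M_3 + M_5$, $C_{21} = M_2 + M_4$, and $C_{22} = M_1 - M_2 + M_3 + M_6$.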

#include <cstdlib>   // rand, srand
#include <ctime>     // clock, time
#include <iostream>
using namespace std;
 
 
template <typename T>
class Strassen
{
public:
	void ADD(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	void SUB(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	void NormalMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	void StrassenMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
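	void FillMatrix(T **  MatrixA, T ** MatrixB, int size); // fill A and B with (identical) random initial values
	int   GetMatrixSum(T ** Matrix, int size);
	// Sums all entries of a matrix. If both algorithms produce the same sum, the result is taken
	// as correct; this is only a quick sanity check, not a proof of correctness.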
};
 
template <typename T>
void Strassen<T>::ADD(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for(int i = 0; i < size; i++)
	{
		for(int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
		}
	}
}
 
template <typename T>
void Strassen<T>::SUB(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for(int i = 0; i < size; i++)
	{
		for(int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
		}
	}
}

template <typename T>
void Strassen<T>::NormalMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for(int i = 0; i < size; i++)
	{
		for(int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = 0;
			for(int k = 0; k < size; k++)
				MatrixResult[i][j] += MatrixA[i][k] * MatrixB[k][j];
		}
	}
}
 
template <typename T>
void Strassen<T>::FillMatrix(T **  MatrixA, T ** MatrixB, int size) // fill A and B with initial values
{
	for(int i = 0; i < size; i++)
	{
		for(int j = 0; j < size; j++)
		{
			MatrixA[i][j] = MatrixB[i][j] = rand() % 5; 
		}
	}	
}
 
template <typename T>
void Strassen<T>::StrassenMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
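	// Optional cutoff: below this size, stop recursing and use the ordinary triple-loop
	// multiplication, which is usually faster for small matrices. To enable it, uncomment
	// the block below and change the following `if` to `else if`.
	// if ( size <= 64 )
	// {
	// 	NormalMul(MatrixA, MatrixB, MatrixResult, size);
	// }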
	if(size == 1)
	{
		MatrixResult[0][0] = MatrixA[0][0] * MatrixB[0][0];
	}
	else
	{
		int half_size = size / 2;
		T ** A11; T ** A12; T ** A21; T ** A22;
		T ** B11; T ** B12; T ** B21; T ** B22;
		T ** C11; T ** C12; T ** C21; T** C22;
		T ** M1; T ** M2; T ** M3; T ** M4; T ** M5; T ** M6; T ** M7;
		T ** MatrixTemp1; T ** MatrixTemp2;
 
		// Allocate the temporary submatrices with element type T (not int),
		// so the template also works for other numeric types
		A11 = new T * [half_size];
		A12 = new T * [half_size];
		A21 = new T * [half_size];
		A22 = new T * [half_size];
 
		B11 = new T * [half_size];
		B12 = new T * [half_size];
		B21 = new T * [half_size];
		B22 = new T * [half_size];
 
		C11 = new T * [half_size];
		C12 = new T * [half_size];
		C21 = new T * [half_size];
		C22 = new T * [half_size];
 
		M1 = new T * [half_size];
		M2 = new T * [half_size];
		M3 = new T * [half_size];
		M4 = new T * [half_size];
		M5 = new T * [half_size];
		M6 = new T * [half_size];
		M7 = new T * [half_size];
		MatrixTemp1 = new T * [half_size];
		MatrixTemp2 = new T * [half_size];
 
		for(int i = 0; i < half_size; i++)
		{
			A11[i] = new T[half_size];
			A12[i] = new T[half_size];
			A21[i] = new T[half_size];
			A22[i] = new T[half_size];
			
			B11[i] = new T[half_size];
			B12[i] = new T[half_size];
			B21[i] = new T[half_size];
			B22[i] = new T[half_size];
			
			C11[i] = new T[half_size];
			C12[i] = new T[half_size];
			C21[i] = new T[half_size];
			C22[i] = new T[half_size];
 
			M1[i] = new T[half_size];
			M2[i] = new T[half_size];
			M3[i] = new T[half_size];
			M4[i] = new T[half_size];
			M5[i] = new T[half_size];
			M6[i] = new T[half_size];
			M7[i] = new T[half_size];
 
			MatrixTemp1[i] = new T[half_size];
			MatrixTemp2[i] = new T[half_size];
		}
 
		// Split A and B into their four quadrants
		for(int i = 0; i < half_size; i++)
		{
			for(int j = 0; j < half_size; j++)
			{
				A11[i][j] = MatrixA[i][j];
				A12[i][j] = MatrixA[i][j+half_size];
				A21[i][j] = MatrixA[i+half_size][j];
				A22[i][j] = MatrixA[i+half_size][j+half_size];
 
				B11[i][j] = MatrixB[i][j];
				B12[i][j] = MatrixB[i][j+half_size];
				B21[i][j] = MatrixB[i+half_size][j];
				B22[i][j] = MatrixB[i+half_size][j+half_size];
			}
		}		
 
		//calculate M1
		ADD(A11, A22, MatrixTemp1, half_size);
		ADD(B11, B22, MatrixTemp2, half_size);
		StrassenMul(MatrixTemp1, MatrixTemp2, M1,half_size);
 
		//calculate M2
		ADD(A21, A22, MatrixTemp1, half_size);
		StrassenMul(MatrixTemp1, B11, M2, half_size);
 
		//calculate M3
		SUB(B12, B22, MatrixTemp1, half_size);
		StrassenMul(A11, MatrixTemp1, M3, half_size);
 
 
		//calculate M4
		SUB(B21, B11, MatrixTemp1, half_size);
		StrassenMul(A22, MatrixTemp1, M4, half_size);
 
		//calculate M5
		ADD(A11, A12, MatrixTemp1, half_size);
		StrassenMul(MatrixTemp1, B22, M5, half_size);
 
		//calculate M6
		SUB(A21, A11, MatrixTemp1, half_size);
		ADD(B11, B12, MatrixTemp2, half_size);
		StrassenMul(MatrixTemp1, MatrixTemp2, M6, half_size);
 
		//calculate M7
		SUB(A12, A22, MatrixTemp1, half_size);
		ADD(B21, B22, MatrixTemp2, half_size);
		StrassenMul(MatrixTemp1, MatrixTemp2, M7, half_size);
 
		//C11
		ADD(M1, M4, C11, half_size);
		SUB(C11, M5, C11, half_size);
		ADD(C11, M7, C11, half_size);
 
		//C12
		ADD(M3, M5, C12, half_size);
 
		//C21
		ADD(M2, M4, C21, half_size);
 
		//C22
		SUB(M1, M2, C22, half_size);
		ADD(C22, M3, C22, half_size);
		ADD(C22, M6, C22, half_size);
 
		// Copy the four quadrants of C back into the result matrix
		for(int i = 0; i < half_size; i++)
		{
			for(int j = 0; j < half_size; j++)
			{
				MatrixResult[i][j] = C11[i][j];
				MatrixResult[i][j+half_size] = C12[i][j];
				MatrixResult[i+half_size][j] = C21[i][j];
				MatrixResult[i+half_size][j+half_size] = C22[i][j];
			}
		}
 
		// Free the temporary matrices
		for(int i = 0; i < half_size; i++)
		{
			delete[] A11[i];	
			delete[] A12[i];	
			delete[] A21[i];	
			delete[] A22[i];
			
			delete[] B11[i];	
			delete[] B12[i];	
			delete[] B21[i];	
			delete[] B22[i];
			
			delete[] C11[i];	
			delete[] C12[i];	
			delete[] C21[i];	
			delete[] C22[i];
 
			delete[] M1[i];	
			delete[] M2[i];	
			delete[] M3[i];	
			delete[] M4[i];	
			delete[] M5[i];	
			delete[] M6[i];	
			delete[] M7[i];
			
			delete[] MatrixTemp1[i];	
			delete[] MatrixTemp2[i];
		}
		delete[] A11;	
		delete[] A12;	
		delete[] A21;	
		delete[] A22;
		
		delete[] B11;	
		delete[] B12;	
		delete[] B21;	
		delete[] B22;
		
		delete[] C11;	
		delete[] C12;	
		delete[] C21;	
		delete[] C22;
 
		delete[] M1;	
		delete[] M2;	
		delete[] M3;	
		delete[] M4;	
		delete[] M5;	
		delete[] M6;	
		delete[] M7;
		
		delete[] MatrixTemp1;	
		delete[] MatrixTemp2;
	}
}
 
template <typename T>
int   Strassen<T>::GetMatrixSum(T ** Matrix, int size)
{
	int sum = 0;
	for(int i = 0; i < size; i++)
	{
		for(int j = 0; j < size; j++)
		{
			sum += Matrix[i][j];
		}
	}	
	return sum;
}
 
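int main()
{
	clock_t startTime_normal, endTime_normal;
	clock_t startTime_strassen, endTime_strassen;
 
	//srand(time(0));   // uncomment to get different random matrices on each run
 
	Strassen<int> stra;
	int N;
	cout << "Please input the size of the matrix (it must be a power of 2):" << endl;
	cin >> N;
	// StrassenMul halves the size at every level, so N must be a positive power of 2
	if (!cin || N <= 0 || (N & (N - 1)) != 0)
	{
		cout << "The size must be a positive power of 2." << endl;
		return 1;
	}
 
	int ** Matrix1 = new int * [N];
	int ** Matrix2 = new int * [N];
	int ** Matrix3 = new int * [N];
	for(int i = 0; i < N; i++)
	{
		Matrix1[i] = new int[N];
		Matrix2[i] = new int[N];
		Matrix3[i] = new int[N];
	}
 
	stra.FillMatrix(Matrix1, Matrix2, N);
 
	// clock() values are in clock ticks; divide by CLOCKS_PER_SEC for seconds
	cout << "Naive algorithm start time: " << (startTime_normal = clock()) << endl;
	stra.NormalMul(Matrix1, Matrix2, Matrix3, N);
	cout << "Naive algorithm end time: " << (endTime_normal = clock()) << endl;
	cout << "Total time: " << endTime_normal - startTime_normal << endl;
	cout << "sum = " << stra.GetMatrixSum(Matrix3, N) << ';' << endl;
 
	cout << "Strassen algorithm start time: " << (startTime_strassen = clock()) << endl;
	stra.StrassenMul(Matrix1, Matrix2, Matrix3, N);
	cout << "Strassen algorithm end time: " << (endTime_strassen = clock()) << endl;
	cout << "Total time: " << endTime_strassen - startTime_strassen << endl;
	cout << "sum = " << stra.GetMatrixSum(Matrix3, N) << ';' << endl;
 
	// Free the input and output matrices
	for(int i = 0; i < N; i++)
	{
		delete[] Matrix1[i];
		delete[] Matrix2[i];
		delete[] Matrix3[i];
	}
	delete[] Matrix1;
	delete[] Matrix2;
	delete[] Matrix3;
	return 0;
}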
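The program above only accepts sizes that are powers of 2. A common workaround, sketched here under the assumption that zero-padding is acceptable for your use, is to place A and B in the top-left corner of zero matrices whose size is the next power of 2, multiply those, and keep the top-left $n \times n$ block of the product; the added zero rows and columns contribute nothing to that block. The helpers `nextPowerOfTwo`, `padTo`, and `cropTo` are hypothetical and use `std::vector` rather than the program's `T **` arrays.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Smallest power of 2 that is >= n (n must be at least 1).
std::size_t nextPowerOfTwo(std::size_t n) {
    assert(n >= 1);
    std::size_t p = 1;
    while (p < n) p *= 2;
    return p;
}

// Places M in the top-left corner of a size x size matrix of zeros.
Matrix padTo(const Matrix& M, std::size_t size) {
    assert(size >= M.size());
    Matrix R(size, std::vector<long long>(size, 0));
    for (std::size_t i = 0; i < M.size(); ++i)
        for (std::size_t j = 0; j < M[i].size(); ++j) R[i][j] = M[i][j];
    return R;
}

// Returns the top-left n x n block of M.
Matrix cropTo(const Matrix& M, std::size_t n) {
    Matrix R(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j) R[i][j] = M[i][j];
    return R;
}
```

For $n = 5$, for example, both inputs become $8 \times 8$, any power-of-2 multiplier is applied to the padded matrices, and cropping the result to $5 \times 5$ recovers $A \cdot B$. In the worst case padding nearly doubles $n$ (for instance 513 becomes 1024), so the extra work can be substantial.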
