矩阵乘法问题:Strassen 算法
Strassen 算法的基本思想是把 A、B、C 每个矩阵都等分为 4 个子块(如 A11、A12、A21、A22),用 7 次子矩阵乘法代替朴素分块乘法所需的 8 次。
求 C = AB 时,先定义以下 7 个矩阵变量:
M1 = (A12-A22)(B21+B22)
M2 = (A11+A22)(B11+B22)
M3 = (A11-A21)(B11+B12)
M4 = (A11+A12)B22
M5 = A11(B12-B22)
M6 = A22(B21-B11)
M7 = (A21+A22)B11。
则 C可以通过这7个变量算出。
C11 = M1+M2-M4+M6。
C12 = M4+M5。
C21 = M6+M7。
C22 = M2-M3+M5-M7。
就可以求出C。
主函数
// Read a matrix size n, build two random n x n matrices A and B, print them,
// compute C = A * B with StrassenMatrix, print C, then release all matrices.
// Returns 1 if the size read from stdin is not a positive integer.
int main() {
    int n = 0;
    cout<<"请输入矩阵的规模,将自动产生矩阵:";
    // Reject non-numeric or non-positive sizes: a negative n would be passed
    // to new[] inside initMatrix as an invalid array length.
    if (!(cin >> n) || n <= 0) {
        cerr << "矩阵规模必须是正整数" << endl;
        return 1;
    }
    int **A = initMatrix(n);  // allocate A
    randomMatrix(A, n);       // fill A with random entries
    int **B = initMatrix(n);
    randomMatrix(B, n);
    int **C = initMatrix(n);  // storage for the product
    printfMatrix(A, n);
    printfMatrix(B, n);
    StrassenMatrix(A, B, C, n);  // C = A * B, n is the matrix size
    printfMatrix(C, n);
    // initMatrix allocates one array per row plus the row-pointer array;
    // release both levels for each matrix.
    int **owned[] = {A, B, C};
    for (int k = 0; k < 3; k++) {
        for (int i = 0; i < n; i++) {
            delete[] owned[k][i];
        }
        delete[] owned[k];
    }
    return 0;
}
功能函数
// Allocate an n x n int matrix as an array of n row pointers, each row a
// separate new[] array of n ints. Entries are left uninitialised. The caller
// owns the memory: delete[] every row, then delete[] the returned array.
int** initMatrix(int n) {
    int **rows = new int *[n];
    int r = 0;
    while (r < n) {
        rows[r] = new int[n];
        ++r;
    }
    return rows;
}
// Fill an n x n matrix with pseudo-random integers in [0, 9] from rand().
//
// The generator is seeded from the clock on the first call only. Re-seeding
// with time(NULL) on every call gives identical matrices to calls made within
// the same second, which previously had to be masked with a one-second
// (Windows-only) Sleep. Letting the rand() sequence continue across calls
// gives distinct matrices without any delay.
void randomMatrix(int** Matrix,int n) {
    static bool seeded = false;
    if (!seeded) {
        srand(static_cast<unsigned>(time(NULL)));
        seeded = true;
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            Matrix[i][j] = rand() % 10;
        }
    }
}
// Print an n x n matrix to standard output: one row per line, each entry
// followed by a single space, each line terminated with endl.
void printfMatrix(int** Matrix,int n) {
    for (int row = 0; row < n; ++row) {
        const int *cells = Matrix[row];
        for (int col = 0; col < n; ++col) {
            cout << cells[col] << " ";
        }
        cout << endl;
    }
}
// Element-wise sum of two n x n matrices: result = m1 + m2.
// result must already be allocated by the caller with n rows of n ints.
void AddMatrix(int** m1, int** m2, int** result, int n) {
    for (int row = 0; row < n; ++row) {
        const int *lhs = m1[row];
        const int *rhs = m2[row];
        int *out = result[row];
        for (int col = 0; col < n; ++col) {
            out[col] = lhs[col] + rhs[col];
        }
    }
}
// Element-wise difference of two n x n matrices: result = m1 - m2.
// result must already be allocated by the caller with n rows of n ints.
void SubMatrix(int** m1, int** m2, int ** result, int n) {
    int i = 0;
    while (i < n) {
        int j = 0;
        while (j < n) {
            result[i][j] = m1[i][j] - m2[i][j];
            ++j;
        }
        ++i;
    }
}
Strassen算法
// Allocate an n x n int matrix (array of n row pointers) with every entry set
// to 0. Release it with strassenFreeMatrix(M, n).
static int** strassenAllocMatrix(int n) {
    int **M = new int *[n];
    for (int i = 0; i < n; i++) {
        M[i] = new int[n]();  // "()" value-initialises the row to zeros
    }
    return M;
}

// Release a matrix obtained from strassenAllocMatrix(n).
static void strassenFreeMatrix(int **M, int n) {
    for (int i = 0; i < n; i++) {
        delete[] M[i];
    }
    delete[] M;
}

/*
 * Compute C = A * B for n x n int matrices using Strassen's algorithm.
 *
 * A, B, C: arrays of n row pointers, each row holding n ints. C must already
 *          be allocated by the caller; all n*n entries are overwritten.
 *          A and B are only read.
 * n:       matrix size. Any n >= 1 is supported: an odd size is zero-padded
 *          to n + 1 before splitting. n <= 0 does nothing.
 *
 * Each level splits the operands into four m x m quarters (m = n / 2) and
 * builds the product from 7 recursive multiplications M1..M7 instead of the
 * 8 a plain block multiplication needs. The recursion bottoms out at 1x1.
 */
void StrassenMatrix(int** A,int** B,int** C,int n) {
    if (n <= 0) {
        return;
    }
    if (n == 1) {
        C[0][0] = A[0][0] * B[0][0];
        return;
    }
    if (n % 2 != 0) {
        // Splitting at n / 2 would drop the last row and column. Embed A and B
        // in zero-padded (n+1) x (n+1) matrices instead; the top-left n x n
        // block of the padded product is exactly A * B.
        const int p = n + 1;
        int **Ap = strassenAllocMatrix(p);
        int **Bp = strassenAllocMatrix(p);
        int **Cp = strassenAllocMatrix(p);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                Ap[i][j] = A[i][j];
                Bp[i][j] = B[i][j];
            }
        }
        StrassenMatrix(Ap, Bp, Cp, p);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = Cp[i][j];
            }
        }
        strassenFreeMatrix(Ap, p);
        strassenFreeMatrix(Bp, p);
        strassenFreeMatrix(Cp, p);
        return;
    }

    const int m = n / 2;  // size of each quarter

    // Quarters of A, B and C.
    int **A_11 = strassenAllocMatrix(m);
    int **A_12 = strassenAllocMatrix(m);
    int **A_21 = strassenAllocMatrix(m);
    int **A_22 = strassenAllocMatrix(m);
    int **B_11 = strassenAllocMatrix(m);
    int **B_12 = strassenAllocMatrix(m);
    int **B_21 = strassenAllocMatrix(m);
    int **B_22 = strassenAllocMatrix(m);
    int **C_11 = strassenAllocMatrix(m);
    int **C_12 = strassenAllocMatrix(m);
    int **C_21 = strassenAllocMatrix(m);
    int **C_22 = strassenAllocMatrix(m);
    // The 7 Strassen products, plus two scratch matrices for operand sums.
    int **M1 = strassenAllocMatrix(m);
    int **M2 = strassenAllocMatrix(m);
    int **M3 = strassenAllocMatrix(m);
    int **M4 = strassenAllocMatrix(m);
    int **M5 = strassenAllocMatrix(m);
    int **M6 = strassenAllocMatrix(m);
    int **M7 = strassenAllocMatrix(m);
    int **t1 = strassenAllocMatrix(m);
    int **t2 = strassenAllocMatrix(m);

    // Split A and B into quarters.
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            A_11[i][j] = A[i][j];
            A_12[i][j] = A[i][j + m];
            A_21[i][j] = A[i + m][j];
            A_22[i][j] = A[i + m][j + m];
            B_11[i][j] = B[i][j];
            B_12[i][j] = B[i][j + m];
            B_21[i][j] = B[i + m][j];
            B_22[i][j] = B[i + m][j + m];
        }
    }

    // M1 = (A12 - A22)(B21 + B22)
    SubMatrix(A_12, A_22, t1, m);
    AddMatrix(B_21, B_22, t2, m);
    StrassenMatrix(t1, t2, M1, m);
    // M2 = (A11 + A22)(B11 + B22)
    AddMatrix(A_11, A_22, t1, m);
    AddMatrix(B_11, B_22, t2, m);
    StrassenMatrix(t1, t2, M2, m);
    // M3 = (A11 - A21)(B11 + B12)
    SubMatrix(A_11, A_21, t1, m);
    AddMatrix(B_11, B_12, t2, m);
    StrassenMatrix(t1, t2, M3, m);
    // M4 = (A11 + A12)B22
    AddMatrix(A_11, A_12, t1, m);
    StrassenMatrix(t1, B_22, M4, m);
    // M5 = A11(B12 - B22). Matrix multiplication is not commutative, so A11
    // must be the LEFT factor (swapping them is only correct at size 1).
    SubMatrix(B_12, B_22, t1, m);
    StrassenMatrix(A_11, t1, M5, m);
    // M6 = A22(B21 - B11)
    SubMatrix(B_21, B_11, t1, m);
    StrassenMatrix(A_22, t1, M6, m);
    // M7 = (A21 + A22)B11
    AddMatrix(A_21, A_22, t1, m);
    StrassenMatrix(t1, B_11, M7, m);

    // C11 = M1 + M2 - M4 + M6
    AddMatrix(M1, M2, t1, m);
    AddMatrix(t1, M6, t2, m);
    SubMatrix(t2, M4, C_11, m);
    // C12 = M4 + M5
    AddMatrix(M4, M5, C_12, m);
    // C21 = M6 + M7
    AddMatrix(M6, M7, C_21, m);
    // C22 = M2 - M3 + M5 - M7
    SubMatrix(M2, M3, t1, m);
    SubMatrix(M5, M7, t2, m);
    AddMatrix(t1, t2, C_22, m);

    // Reassemble the four quarters into C.
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = C_11[i][j];
            C[i][j + m] = C_12[i][j];
            C[i + m][j] = C_21[i][j];
            C[i + m][j + m] = C_22[i][j];
        }
    }

    // Release every temporary allocated at this level.
    int **scratch[] = {A_11, A_12, A_21, A_22, B_11, B_12, B_21, B_22,
                       C_11, C_12, C_21, C_22,
                       M1, M2, M3, M4, M5, M6, M7, t1, t2};
    const int scratchCount = (int)(sizeof scratch / sizeof scratch[0]);
    for (int k = 0; k < scratchCount; k++) {
        strassenFreeMatrix(scratch[k], m);
    }
}