Strassen 方法求矩阵乘法
a: m×n，b: n×k，a 与 b 相乘得到的 c 的维数为 m×k。
为了使每层递归都能把矩阵均分为 4 块，m、n、k 应均为 2 的幂。
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;
// Element type of the matrices handled by SubMat and the routines built on it;
// kept as a typedef so the element type can be changed in one place.
typedef int DATATYPE;

// Non-owning view of a rectangular sub-block of a row-major matrix.
// The buffer `p` holds the full `row` x `col` matrix; the view covers rows
// [subsr, subsr + subRow) and columns [subsc, subsc + subCol) of it.
class SubMat
{
public:
    DATATYPE *p;         // row-major data of the full matrix (not owned)
    int row, col;        // dimensions of the full matrix; `col` is the row stride
    int subsr, subsc;    // start row index and start column index of the sub-block
    int subRow, subCol;  // number of rows and columns of the sub-block

    // Creates a view that covers the whole dataRow x dataCol matrix at datap.
    SubMat(DATATYPE *datap, int dataRow, int dataCol)
        : p(datap), row(dataRow), col(dataCol),
          subsr(0), subsc(0), subRow(dataRow), subCol(dataCol)
    {
    }

    // Returns element (i, j) of the sub-block; indices are relative to the
    // sub-block's top-left corner.
    DATATYPE GetData(int i, int j) const
    {
        return p[Offset(i, j)];
    }

    // Stores val into element (i, j) of the sub-block.
    void SetData(int i, int j, DATATYPE val)
    {
        p[Offset(i, j)] = val;
    }

private:
    // Linear index into `p` of sub-block element (i, j).
    int Offset(int i, int j) const
    {
        return (subsr + i) * col + subsc + j;
    }
};
// Element-wise sum c = a + b over a's sub-block dimensions (b and c are
// expected to be at least that large). Each c(i, j) is written only after
// a(i, j) and b(i, j) have been read, so c may alias a or b.
void MatrixAddAB(SubMat& a,SubMat& b,SubMat& c)
{
    int row = 0;
    while (row < a.subRow)
    {
        for (int col = 0; col < a.subCol; ++col)
        {
            const DATATYPE sum = a.GetData(row, col) + b.GetData(row, col);
            c.SetData(row, col, sum);
        }
        ++row;
    }
}
// Element-wise difference c = a - b over a's sub-block dimensions (b and c
// are expected to be at least that large). Each c(i, j) is written only after
// a(i, j) and b(i, j) have been read, so c may alias a or b.
void MatrixMinusAB(SubMat& a, SubMat& b, SubMat& c)
{
    const int rows = a.subRow;
    const int cols = a.subCol;
    for (int r = 0; r < rows; ++r)
    {
        int k = 0;
        while (k < cols)
        {
            const DATATYPE diff = a.GetData(r, k) - b.GetData(r, k);
            c.SetData(r, k, diff);
            ++k;
        }
    }
}
//recursive function to solve the matrix multiplication
void MatrixMultiplyAB(SubMat &a,SubMat &b,SubMat &c)
{
if (a.subCol == 0 || a.subRow == 0 || b.subRow == 0 || b.subCol == 0)
{
return;
}
else if (a.subCol == 1 || a.subRow == 1 || b.subRow == 1 || b.subCol == 1)
{
for (int i = 0; i < c.subRow; i++)
{
for (int j = 0; j < c.subCol;j++)
{
DATATYPE tmpSum = 0;
for (int k = 0; k < b.subRow;k++)
{
tmpSum += a.GetData(i, k)*b.GetData(k, j);
}
c.SetData(i, j, tmpSum);
}
}
return;
}
SubMat a11(a.p, a.row, a.col), a12(a.p, a.row, a.col), a21(a.p, a.row, a.col), a22(a.p, a.row, a.col),
b11(b.p, b.row, b.col), b12(b.p, b.row, b.col), b21(b.p, b.row, b.col), b22(b.p, b.row, b.col),
c11(c.p, c.row, c.col), c12(c.p, c.row, c.col), c21(c.p, c.row, c.col), c22(c.p, c.row, c.col);
int asubRow = a.subRow / 2, asubCol = a.subCol / 2;
int bsubRow = b.subRow / 2, bsubCol = b.subCol / 2;
a11.subsr = a.subsr, a11.subsc = a.subsc, a11.subRow = asubRow, a11.subCol = asubCol;
a12.subsr = a11.subsr, a12.subsc = a.subsc + asubCol, a12.subRow = asubRow, a12.subCol = asubCol;
a21.subsr = a.subsr + asubRow, a21.subsc = a.subsc, a21.subRow = asubRow, a21.subCol = asubCol;
a22.subsr = a21.subsr, a22.subsc = a12.subsc, a22.subRow = asubRow, a22.subCol = asubCol;
b11.subsr = b.subsr, b11.subsc = b.subsc, b11.subRow = bsubRow, b11.subCol = bsubCol;
b12.subsr = b11.subsr, b12.subsc = b.subsc + bsubCol, b12.subRow = bsubRow, b12.subCol = bsubCol;
b21.subsr = b.subsr + bsubRow, b21.subsc = b.subsc, b21.subRow = bsubRow, b21.subCol = bsubCol;
b22.subsr = b21.subsr, b22.subsc = b12.subsc, b22.subRow = bsubRow, b22.subCol = bsubCol;
c11.subsr = c.subsr, c11.subsc = c.subsc, c11.subRow = asubRow, c11.subCol = bsubCol;
c12.subsr = c.subsr, c12.subsc = c.subsc + bsubCol, c12.subRow = asubRow, c12.subCol = bsubCol;
c21.subsr = c.subsr + asubRow, c21.subsc = c.subsc, c21.subRow = asubRow, c21.subCol = bsubCol;
c22.subsr = c.subsr + asubRow, c22.subsc = c.subsc + bsubCol, c22.subRow = asubRow, c22.subCol = bsubCol;
/*S1 = B12 - B22, S2 = A11 + A12, S3 = A21 + A22,
S4 = B21 - B11, S5 = A11 + A22, S6 = B11 + B22,
S7 = A12 - A22, S8 = B21 + B22, S9 = A11 - A21,
S10 = B11 + B12*/
DATATYPE *s1 = new DATATYPE[bsubRow*bsubCol], *s2 = new DATATYPE[asubRow*asubCol], *s3 = new DATATYPE[asubRow*asubCol],
*s4 = new DATATYPE[bsubRow*bsubCol], *s5 = new DATATYPE[asubRow*asubCol], *s6 = new DATATYPE[bsubRow*bsubCol],
*s7 = new DATATYPE[asubRow*asubCol], *s8 = new DATATYPE[bsubRow*bsubCol], *s9 = new DATATYPE[asubRow*asubCol],
*s10 = new DATATYPE[bsubRow*bsubCol];
SubMat S1(s1, bsubRow, bsubCol), S2(s2, asubRow, asubCol), S3(s3, asubRow, asubCol),
S4(s4, bsubRow, bsubCol), S5(s5, asubRow, asubCol), S6(s6, bsubRow, bsubCol),
S7(s7, asubRow, asubCol), S8(s8, bsubRow, bsubCol), S9(s9, asubRow, asubCol),
S10(s10, bsubRow, bsubCol);
MatrixMinusAB(b12,b22,S1);
MatrixAddAB(a11,a12,S2);
MatrixAddAB(a21,a22,S3);
MatrixMinusAB(b21,b11,S4);
MatrixAddAB(a11,a22,S5);
MatrixAddAB(b11,b22,S6);
MatrixMinusAB(a12,a22,S7);
MatrixAddAB(b21,b22,S8);
MatrixMinusAB(a11,a21,S9);
MatrixAddAB(b11,b12,S10);
/*P1 = A11 * S1, P2 = S2 * B22, P3 = S3 * B11, P4 = A22 * S4
P5 = S5 * S6, P6 = S7 * S8, P7 = S9 * S10 */
int num = asubRow*bsubCol;
DATATYPE *p1 = new DATATYPE[num], *p2 = new DATATYPE[num], *p3 = new DATATYPE[num],
*p4 = new DATATYPE[num], *p5 = new DATATYPE[num], *p6 = new DATATYPE[num],
*p7 = new DATATYPE[num];
SubMat P1(p1, asubRow, bsubCol), P2(p2, asubRow, bsubCol), P3(p3, asubRow, bsubCol),
P4(p4, asubRow, bsubCol), P5(p5, asubRow, bsubCol), P6(p6, asubRow, bsubCol),
P7(p7, asubRow, bsubCol);
MatrixMultiplyAB(a11,S1,P1);
MatrixMultiplyAB(S2,b22,P2);
MatrixMultiplyAB(S3,b11,P3);
MatrixMultiplyAB(a22,S4,P4);
MatrixMultiplyAB(S5,S6,P5);
MatrixMultiplyAB(S7,S8,P6);
MatrixMultiplyAB(S9,S10,P7);
/*C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7*/
MatrixAddAB(P5, P4, c11);
MatrixMinusAB(c11, P2, c11);
MatrixAddAB(c11, P6, c11);
MatrixAddAB(P1,P2,c12);
MatrixAddAB(P3,P4,c21);
MatrixAddAB(P5,P1,c22);
MatrixMinusAB(c22,P3,c22);
MatrixMinusAB(c22,P7,c22);
delete[] s1, delete[] s2, delete[] s3, delete[] s4, delete[] s5, delete[] s6, delete[] s7, delete[] s8, delete[] s9, delete[] s10;
delete[] p1, delete[] p2, delete[] p3, delete[] p4, delete[] p5, delete[] p6, delete[] p7;
}
//package the recursive function to make the input parameters in a common 2-D array form
void MatrixMultipy(int* a, int arow, int acol, int* b, int brow, int bcol, int* c)
{
SubMat A(a, arow, acol), B(b, brow, bcol),C(c, arow, bcol);
MatrixMultiplyAB(A,B,C);
}
使用一个较大的矩阵进行测试：
// Demo: multiplies a 16x8 matrix by an 8x4 matrix with MatrixMultipy and
// prints the 16x4 product, one tab-separated row per line.
int main()
{
    int a[16][8] = {
        { 1, 1, 1, 1, 2, 66, 456, 4 },
        { 1, 2, 2, 3, 4, 7, 9, 10 },
        { 1, 2, 2, 3, 4, 4, 9, 10 },
        { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 },
        { 1, 2, 2, 3, 4, 7, 9, 10 },
        { 1, 2, 2, 3, 4, 4, 9, 10 },
        { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 },
        { 1, 2, 2, 3, 4, 7, 9, 10 },
        { 1, 2, 2, 3, 4, 4, 9, 10 },
        { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 },
        { 1, 2, 2, 3, 4, 7, 9, 10 },
        { 1, 2, 2, 3, 4, 4, 9, 10 },
        { 1, 56, 2, 3, 4, 4, 9, 10 }
    };
    int b[8][4] = {
        { 1, 5, 6, 1 },
        { 3, 22, 1, 1 },
        { 3, 5, 0, 3 },
        { 4, 24, 11, 5 },
        { 3, 5, 67, 3 },
        { 7, 5, 67, 3 },
        { 9, 5, 67, 3 },
        { 3, 5, 27, 3 }
    };
    // Zero-initialised so that any element the multiply fails to write shows
    // up as 0 rather than as an indeterminate value.
    int c[16][4] = {};
    MatrixMultipy(&a[0][0], 16, 8, &b[0][0], 8, 4, &c[0][0]);
    for (int i = 0; i < 16; i++)
    {
        for (int j = 0; j < 4; j++)
        {
            std::cout << c[i][j] << "\t";
        }
        std::cout << std::endl;
    }
#ifdef _WIN32
    // "pause" is a cmd.exe built-in, so it is only run on Windows; presumably
    // it is there to keep the console window open after the run.
    std::system("pause");
#endif
    return 0;
}
测试结果：
使用 MATLAB 验证一下：