矩阵的乘法问题(分治算法)
矩阵 $C=AB$ 可以表示为分块矩阵乘法:
$$\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}$$
然后可以拆为以下子问题:
$$C_{11}=A_{11}B_{11}+A_{12}B_{21}$$
$$C_{12}=A_{11}B_{12}+A_{12}B_{22}$$
$$C_{21}=A_{21}B_{11}+A_{22}B_{21}$$
$$C_{22}=A_{21}B_{12}+A_{22}B_{22}$$
最小子问题为:矩阵大小为 $2\times 2$,可直接运算
代码实现
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
/*
 * Four equally sized square quadrants of a matrix.
 *
 * Each quadrant is an array of `blocksize` row pointers, and every row holds
 * `blocksize` ints. The constructor allocates all four quadrants and fills
 * them with zeros. The struct has no destructor, so the storage stays
 * allocated until the owner releases it explicitly.
 */
typedef struct matrixBlock
{
    int blocksize;
    int ** block11;
    int ** block12;
    int ** block21;
    int ** block22;

    /* Allocate four zero-filled _blocksize x _blocksize quadrants. */
    matrixBlock(const int& _blocksize)
        : blocksize(_blocksize),
          block11(new int * [_blocksize]),
          block12(new int * [_blocksize]),
          block21(new int * [_blocksize]),
          block22(new int * [_blocksize])
    {
        int row = 0;
        while (row < blocksize)
        {
            /* Value-initialisation zeroes every element of the new row. */
            block11[row] = new int[blocksize]();
            block12[row] = new int[blocksize]();
            block21[row] = new int[blocksize]();
            block22[row] = new int[blocksize]();
            ++row;
        }
    }
}matrixBlock;
// Forward declarations for the matrix helpers defined below.
// Every matrix is square and stored as an array of row pointers (int **).

// Returns a newly allocated n x n matrix holding the element-wise sum A + B.
int ** matrixAddition(int ** A, int ** B, int n);
// Returns a newly allocated copy of the square sub-block of squareMatrix whose
// top-left corner is (xl, yl) and whose side length is xh - xl (yh is not read).
int ** getMatrixBlockByBound(const int ** squareMatrix, const int& xl, const int& xh, const int& yl, const int& yh);
// Copies matblok into rows [xl, xh) and columns [yl, yh) of resultMatrix.
void setMatrixBlockByBound(int ** resultMatrix, const int ** matblok, const int& xl, const int& xh, const int& yl, const int& yh);
// Splits an n x n matrix into four quadrant copies held by a new matrixBlock.
matrixBlock * matrixSplitBlock(const int ** squareMatrix, const int& n);
// Returns a newly allocated n x n matrix assembled from four quadrants.
int ** matrixMergeBlock(const int ** blok11, const int ** blok12, const int ** blok21, const int ** blok22, const int& n);
// Direct triple-loop product of two n x n matrices; the result is newly allocated.
int ** matrixBaseMulti(int ** A, int ** B, int n);
// Recursive divide-and-conquer product of two n x n matrices.
int ** matrixMultiply(int ** A, int ** B, int n);
// Prints signstr on its own line, then the n x n matrix one row per line.
void matrixShow(const char * signstr, const int ** squareMatrix, const int& n);
/*
 * Element-wise sum of two n x n matrices.
 *
 * A, B: operands, each an array of n row pointers to n ints; not modified.
 * n:    side length.
 * Returns a freshly allocated n x n matrix; the caller owns it.
 */
int ** matrixAddition(int ** A, int ** B, int n)
{
    int ** sum = new int * [n];
    int row = 0;
    while (row < n)
    {
        const int * lhs = A[row];
        const int * rhs = B[row];
        int * out = new int[n]();
        for (int col = 0; col < n; ++col)
        {
            out[col] = lhs[col] + rhs[col];
        }
        sum[row] = out;
        ++row;
    }
    return sum;
}
/*
 * Copy a square sub-block out of squareMatrix.
 *
 * The block starts at row xl, column yl and has side length xh - xl.
 * yh is accepted for symmetry with setMatrixBlockByBound but is not read.
 * Returns a freshly allocated matrix; the caller owns it.
 */
int ** getMatrixBlockByBound(const int ** squareMatrix, const int& xl, const int& xh, const int& yl, const int& yh)
{
    const int side = xh - xl;
    int ** block = new int * [side];
    for (int r = 0; r < side; ++r)
    {
        /* Source row shifted so that index 0 is the block's first column. */
        const int * src = squareMatrix[xl + r] + yl;
        int * dst = new int[side];
        for (int c = 0; c < side; ++c)
        {
            dst[c] = src[c];
        }
        block[r] = dst;
    }
    return block;
}
/*
 * Paste matblok into resultMatrix.
 *
 * Element (r, c) of matblok is written to resultMatrix[xl + r][yl + c] for
 * every r in [0, xh - xl) and c in [0, yh - yl). Nothing outside that
 * rectangle is touched.
 */
void setMatrixBlockByBound(int ** resultMatrix, const int ** matblok, const int& xl, const int& xh, const int& yl, const int& yh)
{
    const int rows = xh - xl;
    const int cols = yh - yl;
    for (int r = 0; r < rows; ++r)
    {
        int * dst = resultMatrix[xl + r] + yl;
        const int * src = matblok[r];
        for (int c = 0; c < cols; ++c)
        {
            dst[c] = src[c];
        }
    }
}
/*
 * Split an n x n matrix into its four quadrants.
 *
 * squareMatrix: source matrix (array of n row pointers); not modified.
 * n:            side length; the quadrants line up exactly only when n is even.
 *
 * Returns a newly allocated matrixBlock whose block11/12/21/22 are
 * independent copies of the top-left, top-right, bottom-left and bottom-right
 * quadrants. The caller owns the matrixBlock and all four quadrants.
 */
matrixBlock * matrixSplitBlock(const int ** squareMatrix, const int& n)
{
    const int half = n / 2;
    matrixBlock * matblok = new matrixBlock(half);

    /*
     * The constructor has already allocated four zeroed half x half
     * quadrants. They are about to be replaced by the copies made below, so
     * release them first; otherwise every split leaks them.
     */
    for (int i = 0; i < half; i++)
    {
        delete[] matblok->block11[i];
        delete[] matblok->block12[i];
        delete[] matblok->block21[i];
        delete[] matblok->block22[i];
    }
    delete[] matblok->block11;
    delete[] matblok->block12;
    delete[] matblok->block21;
    delete[] matblok->block22;

    matblok->block11 = getMatrixBlockByBound(squareMatrix, 0, half, 0, half);
    matblok->block12 = getMatrixBlockByBound(squareMatrix, 0, half, half, n);
    matblok->block21 = getMatrixBlockByBound(squareMatrix, half, n, 0, half);
    matblok->block22 = getMatrixBlockByBound(squareMatrix, half, n, half, n);
    return matblok;
}
/*
 * Assemble an n x n matrix from four quadrants.
 *
 * blok11/blok12 fill the top half (left/right) and blok21/blok22 the bottom
 * half (left/right); the boundary between halves is at index n / 2.
 * The quadrants are copied, not adopted. Returns a freshly allocated matrix
 * owned by the caller.
 */
int ** matrixMergeBlock(const int ** blok11, const int ** blok12, const int ** blok21, const int ** blok22, const int& n)
{
    const int mid = n / 2;

    /* Start from an all-zero n x n canvas. */
    int ** merged = new int * [n];
    int row = 0;
    while (row < n)
    {
        merged[row] = new int[n]();
        ++row;
    }

    /* Top half. */
    setMatrixBlockByBound(merged, blok11, 0, mid, 0, mid);
    setMatrixBlockByBound(merged, blok12, 0, mid, mid, n);
    /* Bottom half. */
    setMatrixBlockByBound(merged, blok21, mid, n, 0, mid);
    setMatrixBlockByBound(merged, blok22, mid, n, mid, n);

    return merged;
}
/*
 * Direct (schoolbook) product of two n x n matrices.
 *
 * A, B: operands, each an array of n row pointers to n ints; not modified.
 * n:    side length.
 * Returns a freshly allocated n x n matrix with
 * result[i][j] = sum over k of A[i][k] * B[k][j]; the caller owns it.
 */
int ** matrixBaseMulti(int ** A, int ** B, int n)
{
    int ** product = new int * [n];
    for (int i = 0; i < n; ++i)
    {
        const int * lhsRow = A[i];
        /* Rows start zeroed so each cell can accumulate in place. */
        int * outRow = new int[n]();
        for (int j = 0; j < n; ++j)
        {
            int k = 0;
            while (k < n)
            {
                outRow[j] += lhsRow[k] * B[k][j];
                ++k;
            }
        }
        product[i] = outRow;
    }
    return product;
}
int ** matrixMultiply(int ** A, int ** B, int n)
{
if (n == 2)
{
return matrixBaseMulti(A, B, n);
}
matrixBlock * matblokA = matrixSplitBlock((const int **)A, n);
matrixBlock * matblokB = matrixSplitBlock((const int **)B, n);
matrixBlock * matblokC = new matrixBlock(n / 2);
matblokC->block11 = matrixAddition(matrixMultiply(matblokA->block11, matblokB->block11, n / 2),
matrixMultiply(matblokA->block12, matblokB->block21, n / 2), n / 2);
matblokC->block12 = matrixAddition(matrixMultiply(matblokA->block11, matblokB->block12, n / 2),
matrixMultiply(matblokA->block12, matblokB->block22, n / 2), n / 2);
matblokC->block21 = matrixAddition(matrixMultiply(matblokA->block21, matblokB->block11, n / 2),
matrixMultiply(matblokA->block22, matblokB->block21, n / 2), n / 2);
matblokC->block22 = matrixAddition(matrixMultiply(matblokA->block21, matblokB->block12, n / 2),
matrixMultiply(matblokA->block22, matblokB->block22, n / 2), n / 2);
return matrixMergeBlock((const int **)matblokC->block11, (const int **)matblokC->block12, (const int **)matblokC->block21, (const int **)matblokC->block22, n);
}
/*
 * Print a labelled n x n matrix to stdout.
 *
 * Output: signstr on its own line, then one line per row with every element
 * followed by a single space, then one empty line.
 */
void matrixShow(const char * signstr, const int ** squareMatrix, const int& n)
{
    puts(signstr);
    int row = 0;
    while (row < n)
    {
        const int * cells = squareMatrix[row];
        for (int col = 0; col < n; ++col)
        {
            printf("%d ", cells[col]);
        }
        printf("\n");
        ++row;
    }
    printf("\n");
}
int main()
{
int matrixlength = 4;
int ** matrixC = new int * [matrixlength];
int ** matrixA = new int * [matrixlength];
int ** matrixB = new int * [matrixlength];
for (int i = 0; i < matrixlength; i++)
{
matrixC[i] = new int[matrixlength]{0};
matrixA[i] = new int[matrixlength]{i, i + 1, i + 2, i + 3};
matrixB[i] = new int[matrixlength]{i, 0, i + 1, 0};
}
matrixC = matrixMultiply(matrixA, matrixB, 4);
matrixShow("A", (const int **)matrixA, matrixlength);
matrixShow("B", (const int **)matrixB, matrixlength);
matrixShow("C", (const int **)matrixC, matrixlength);
return 0;
}