矩阵的乘法问题(分治算法)

矩阵的乘法问题(分治算法)

矩阵乘积 $C=AB$ 可以表示为分块矩阵乘法:
$$
\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
$$
然后可以拆为以下子问题:
$$C_{11}=A_{11}B_{11}+A_{12}B_{21}$$
$$C_{12}=A_{11}B_{12}+A_{12}B_{22}$$
$$C_{21}=A_{21}B_{11}+A_{22}B_{21}$$
$$C_{22}=A_{21}B_{12}+A_{22}B_{22}$$
最小子问题为:矩阵大小为 $2\times 2$,可直接运算。

代码实现

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

/*
 * A square matrix viewed as four equally-sized quadrants:
 *
 *     [ block11  block12 ]
 *     [ block21  block22 ]
 *
 * Each quadrant is a blocksize x blocksize matrix stored as an array of row
 * pointers, every row allocated with new int[]. No destructor is defined:
 * code that replaces or discards the quadrant pointers must release them.
 */
typedef struct matrixBlock
{
    int blocksize;   // side length of every quadrant
    int ** block11;  // top-left quadrant
    int ** block12;  // top-right quadrant
    int ** block21;  // bottom-left quadrant
    int ** block22;  // bottom-right quadrant

    // Creates four zero-filled quadrants of side _blocksize.
    matrixBlock(const int& _blocksize)
        : blocksize(_blocksize),
          block11(newZeroSquare(_blocksize)),
          block12(newZeroSquare(_blocksize)),
          block21(newZeroSquare(_blocksize)),
          block22(newZeroSquare(_blocksize))
    {
    }

    // Allocates a side x side matrix (array of row pointers) filled with zeros.
    static int ** newZeroSquare(int side)
    {
        int ** rows = new int * [side];
        for (int r = 0; r < side; ++r)
        {
            rows[r] = new int[side]{0};
        }
        return rows;
    }
}matrixBlock;

/*
 * Forward declarations. Every function below that returns int ** hands back a
 * freshly new[]-allocated matrix (array of row pointers, each row new int[])
 * that the caller owns and must release with delete[].
 */
int ** matrixAddition(int ** A, int ** B, int n);  // element-wise A + B (n x n)
int ** getMatrixBlockByBound(const int ** squareMatrix, const int& xl, const int& xh, const int& yl, const int& yh);  // copy of a square sub-block starting at (xl, yl)
void setMatrixBlockByBound(int ** resultMatrix, const int ** matblok, const int& xl, const int& xh, const int& yl, const int& yh);  // write a block into rows [xl, xh), cols [yl, yh)
matrixBlock * matrixSplitBlock(const int ** squareMatrix, const int& n);  // split n x n into four quadrants
int ** matrixMergeBlock(const int ** blok11, const int ** blok12, const int ** blok21, const int ** blok22, const int& n);  // assemble four quadrants into n x n
int ** matrixBaseMulti(int ** A, int ** B, int n);  // schoolbook A * B
int ** matrixMultiply(int ** A, int ** B, int n);  // divide-and-conquer A * B
void matrixShow(const char * signstr, const int ** squareMatrix, const int& n);  // print a labelled matrix to stdout

/*
 * Element-wise sum of two n x n matrices.
 * Returns a new n x n matrix (rows allocated with new int[]); caller owns it.
 */
int ** matrixAddition(int ** A, int ** B, int n)
{
    int ** sum = new int * [n];
    for (int row = 0; row < n; ++row)
    {
        int * dst = new int[n]{0};
        const int * lhs = A[row];
        const int * rhs = B[row];
        for (int col = 0; col < n; ++col)
        {
            dst[col] = lhs[col] + rhs[col];
        }
        sum[row] = dst;
    }
    return sum;
}

/*
 * Copies a square sub-matrix out of squareMatrix.
 *
 * The block's top-left element is squareMatrix[xl][yl] and its side length is
 * xh - xl. yh is not consulted; the column extent is assumed to equal the row
 * extent. Returns a new matrix (rows allocated with new int[]); caller owns it.
 */
int ** getMatrixBlockByBound(const int ** squareMatrix, const int& xl, const int& xh, const int& yl, const int& yh)
{
    const int side = xh - xl;
    int ** sub = new int * [side];
    for (int r = 0; r < side; ++r)
    {
        sub[r] = new int[side]{0};
        const int * src = squareMatrix[xl + r];
        for (int c = 0; c < side; ++c)
        {
            sub[r][c] = src[yl + c];
        }
    }
    return sub;
}

/*
 * Writes matblok into resultMatrix so that matblok[0][0] lands at (xl, yl).
 * Rows [xl, xh) and columns [yl, yh) of resultMatrix are overwritten; matblok
 * must have at least (xh - xl) rows and (yh - yl) columns.
 */
void setMatrixBlockByBound(int ** resultMatrix, const int ** matblok, const int& xl, const int& xh, const int& yl, const int& yh)
{
    const int rows = xh - xl;
    const int cols = yh - yl;
    for (int r = 0; r < rows; ++r)
    {
        int * dst = resultMatrix[xl + r];
        const int * src = matblok[r];
        for (int c = 0; c < cols; ++c)
        {
            dst[yl + c] = src[c];
        }
    }
}

/*
 * Splits an n x n matrix into four (n/2) x (n/2) quadrants.
 *
 * Returns a heap-allocated matrixBlock whose quadrants are independent copies
 * of the corresponding parts of squareMatrix. The caller owns the block and
 * its four quadrant matrices (each n/2 rows allocated with new int[]).
 * n is expected to be even; for odd n the last row and column are not copied,
 * so all four quadrants still have side n/2 as recorded in blocksize.
 */
matrixBlock * matrixSplitBlock(const int ** squareMatrix, const int& n)
{
    const int half = n / 2;
    // The constructor already allocates four zero-filled half x half quadrants.
    // Fill them in place: replacing the pointers would leak those allocations.
    matrixBlock * matblok = new matrixBlock(half);
    for (int i = 0; i < half; i++)
    {
        for (int j = 0; j < half; j++)
        {
            matblok->block11[i][j] = squareMatrix[i][j];
            matblok->block12[i][j] = squareMatrix[i][j + half];
            matblok->block21[i][j] = squareMatrix[i + half][j];
            matblok->block22[i][j] = squareMatrix[i + half][j + half];
        }
    }
    return matblok;
}

/*
 * Assembles an n x n matrix from four quadrants:
 *
 *     [ blok11  blok12 ]
 *     [ blok21  blok22 ]
 *
 * The quadrants are copied, not adopted. Returns a new n x n matrix (rows
 * allocated zero-filled with new int[]); caller owns it.
 */
int ** matrixMergeBlock(const int ** blok11, const int ** blok12, const int ** blok21, const int ** blok22, const int& n)
{
    const int mid = n / 2;
    int ** merged = new int * [n];
    for (int r = 0; r < n; ++r)
    {
        merged[r] = new int[n]{0};
    }

    // Quadrant table indexed by [block row][block column] with matching bounds.
    const int ** quadrant[2][2] = { { blok11, blok12 }, { blok21, blok22 } };
    const int lower[2] = { 0, mid };
    const int upper[2] = { mid, n };
    for (int br = 0; br < 2; ++br)
    {
        for (int bc = 0; bc < 2; ++bc)
        {
            setMatrixBlockByBound(merged, quadrant[br][bc], lower[br], upper[br], lower[bc], upper[bc]);
        }
    }
    return merged;
}

/*
 * Schoolbook O(n^3) product of two n x n matrices.
 * Returns a new n x n matrix (rows allocated with new int[]); caller owns it.
 */
int ** matrixBaseMulti(int ** A, int ** B, int n)
{
    int ** product = new int * [n];
    for (int row = 0; row < n; ++row)
    {
        int * out = new int[n]{0};
        const int * aRow = A[row];
        for (int col = 0; col < n; ++col)
        {
            int acc = 0;
            for (int k = 0; k < n; ++k)
            {
                acc += aRow[k] * B[k][col];
            }
            out[col] = acc;
        }
        product[row] = out;
    }
    return product;
}

int ** matrixMultiply(int ** A, int ** B, int n)
{
    if (n == 2)
    {
        return matrixBaseMulti(A, B, n);
    }
    matrixBlock * matblokA = matrixSplitBlock((const int **)A, n);
    matrixBlock * matblokB = matrixSplitBlock((const int **)B, n);
    matrixBlock * matblokC = new matrixBlock(n / 2);
    matblokC->block11 = matrixAddition(matrixMultiply(matblokA->block11, matblokB->block11, n / 2), 
                        matrixMultiply(matblokA->block12, matblokB->block21, n / 2), n / 2);
    matblokC->block12 = matrixAddition(matrixMultiply(matblokA->block11, matblokB->block12, n / 2), 
                        matrixMultiply(matblokA->block12, matblokB->block22, n / 2), n / 2);
    matblokC->block21 = matrixAddition(matrixMultiply(matblokA->block21, matblokB->block11, n / 2), 
                        matrixMultiply(matblokA->block22, matblokB->block21, n / 2), n / 2);
    matblokC->block22 = matrixAddition(matrixMultiply(matblokA->block21, matblokB->block12, n / 2), 
                        matrixMultiply(matblokA->block22, matblokB->block22, n / 2), n / 2);
    return matrixMergeBlock((const int **)matblokC->block11, (const int **)matblokC->block12, (const int **)matblokC->block21, (const int **)matblokC->block22, n);
}

/*
 * Prints signstr on its own line, then the n x n matrix one row per line
 * (each value followed by a space), then a blank line, all to stdout.
 */
void matrixShow(const char * signstr, const int ** squareMatrix, const int& n)
{
    puts(signstr);
    for (int row = 0; row < n; ++row)
    {
        const int * values = squareMatrix[row];
        for (int col = 0; col < n; ++col)
        {
            printf("%d ", values[col]);
        }
        printf("\n");
    }
    printf("\n");
}

int main()
{
    int matrixlength = 4;
    int ** matrixC = new int * [matrixlength];
    int ** matrixA = new int * [matrixlength];
    int ** matrixB = new int * [matrixlength];
    for (int i = 0; i < matrixlength; i++)
    {
        matrixC[i] = new int[matrixlength]{0};
        matrixA[i] = new int[matrixlength]{i, i + 1, i + 2, i + 3};
        matrixB[i] = new int[matrixlength]{i, 0, i + 1, 0};
    }
    
    matrixC = matrixMultiply(matrixA, matrixB, 4);
    matrixShow("A", (const int **)matrixA, matrixlength);
    matrixShow("B", (const int **)matrixB, matrixlength);
    matrixShow("C", (const int **)matrixC, matrixlength);
    return 0;
}
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值