Strassen方法求矩阵乘法

说明:用 Strassen 方法计算矩阵乘法 c = a × b。
a 的维数为 m×n,b 的维数为 n×k,相乘后 c 的维数为 m×k。
矩阵维数要求:m、n、k 均为 2 的幂。

#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;
// Element type of every matrix handled in this file; change it here to switch
// the whole implementation to another numeric type in one place.
typedef int DATATYPE;

// SubMat is a lightweight view of a rectangular sub-block of a row-major matrix.
// It does not own the memory it points at (there is no destructor).
//   p              - pointer to element (0, 0) of the full underlying matrix
//   row, col       - dimensions of the full matrix; col is the row stride
//   subsr, subsc   - row / column offset of the view's top-left corner
//   subRow, subCol - dimensions of the view
// A freshly constructed SubMat views the whole matrix.
class SubMat
{
public:
    DATATYPE *p;
    int row, col;       // total row and column count of the underlying matrix
    int subsr, subsc;   // start row index and column index of the sub matrix
    int subRow, subCol; // row and column count of the sub matrix

    SubMat(DATATYPE *datap, int dataRow, int dataCol)
        : p(datap), row(dataRow), col(dataCol),
          subsr(0), subsc(0), subRow(dataRow), subCol(dataCol)
    {
    }

    // Reads element (i, j) of the view. This is const so read-only views can be queried.
    DATATYPE GetData(int i, int j) const
    {
        return p[Offset(i, j)];
    }

    // Writes val to element (i, j) of the view (i.e. into the underlying matrix).
    void SetData(int i, int j, DATATYPE val)
    {
        p[Offset(i, j)] = val;
    }

private:
    // Linear index into p of view element (i, j).
    int Offset(int i, int j) const
    {
        return (subsr + i) * col + subsc + j;
    }
};
// Element-wise sum: c(i, j) = a(i, j) + b(i, j) for every cell of a's view.
// The extent (rows x columns) is taken from a; b and c are assumed to be at
// least that large. Each cell is read before it is written, so c may be the
// same view as a or b (in-place accumulation).
void MatrixAddAB(SubMat& a,SubMat& b,SubMat& c)
{
    if (a.subRow <= 0 || a.subCol <= 0)
    {
        return; // empty view: nothing to add
    }
    const int cells = a.subRow * a.subCol;
    for (int idx = 0; idx < cells; ++idx)
    {
        // Row-major walk: same visiting order as a nested row/column loop.
        const int r = idx / a.subCol;
        const int k = idx % a.subCol;
        c.SetData(r, k, a.GetData(r, k) + b.GetData(r, k));
    }
}
// Element-wise difference: c(i, j) = a(i, j) - b(i, j) for every cell of a's view.
// The extent (rows x columns) is taken from a; b and c are assumed to be at
// least that large. Each cell is read before it is written, so c may be the
// same view as a or b (in-place accumulation).
void MatrixMinusAB(SubMat& a, SubMat& b, SubMat& c)
{
    if (a.subRow <= 0 || a.subCol <= 0)
    {
        return; // empty view: nothing to subtract
    }
    const int cells = a.subRow * a.subCol;
    for (int idx = 0; idx < cells; ++idx)
    {
        // Row-major walk: same visiting order as a nested row/column loop.
        const int r = idx / a.subCol;
        const int k = idx % a.subCol;
        c.SetData(r, k, a.GetData(r, k) - b.GetData(r, k));
    }
}
//recursive function to solve the matrix multiplication
void MatrixMultiplyAB(SubMat &a,SubMat &b,SubMat &c)
{
    if (a.subCol == 0 || a.subRow == 0 || b.subRow == 0 || b.subCol == 0)
    {
        return;
    }
    else if (a.subCol == 1 || a.subRow == 1 || b.subRow == 1 || b.subCol == 1)
    {

        for (int i = 0; i < c.subRow; i++)
        {
            for (int j = 0; j < c.subCol;j++)
            {
                DATATYPE tmpSum = 0;
                for (int k = 0; k < b.subRow;k++)
                {
                    tmpSum += a.GetData(i, k)*b.GetData(k, j);
                }
                c.SetData(i, j, tmpSum);
            }
        }
        return;
    }
    SubMat a11(a.p, a.row, a.col), a12(a.p, a.row, a.col), a21(a.p, a.row, a.col), a22(a.p, a.row, a.col),
        b11(b.p, b.row, b.col), b12(b.p, b.row, b.col), b21(b.p, b.row, b.col), b22(b.p, b.row, b.col),
        c11(c.p, c.row, c.col), c12(c.p, c.row, c.col), c21(c.p, c.row, c.col), c22(c.p, c.row, c.col);
    int asubRow = a.subRow / 2, asubCol = a.subCol / 2;
    int bsubRow = b.subRow / 2, bsubCol = b.subCol / 2;
    a11.subsr = a.subsr, a11.subsc = a.subsc, a11.subRow = asubRow, a11.subCol = asubCol;
    a12.subsr = a11.subsr, a12.subsc = a.subsc + asubCol, a12.subRow = asubRow, a12.subCol = asubCol;
    a21.subsr = a.subsr + asubRow, a21.subsc = a.subsc, a21.subRow = asubRow, a21.subCol = asubCol;
    a22.subsr = a21.subsr, a22.subsc = a12.subsc, a22.subRow = asubRow, a22.subCol = asubCol;
    b11.subsr = b.subsr, b11.subsc = b.subsc, b11.subRow = bsubRow, b11.subCol = bsubCol;
    b12.subsr = b11.subsr, b12.subsc = b.subsc + bsubCol, b12.subRow = bsubRow, b12.subCol = bsubCol;
    b21.subsr = b.subsr + bsubRow, b21.subsc = b.subsc, b21.subRow = bsubRow, b21.subCol = bsubCol;
    b22.subsr = b21.subsr, b22.subsc = b12.subsc, b22.subRow = bsubRow, b22.subCol = bsubCol;
    c11.subsr = c.subsr, c11.subsc = c.subsc, c11.subRow = asubRow, c11.subCol = bsubCol;
    c12.subsr = c.subsr, c12.subsc = c.subsc + bsubCol, c12.subRow = asubRow, c12.subCol = bsubCol;
    c21.subsr = c.subsr + asubRow, c21.subsc = c.subsc, c21.subRow = asubRow, c21.subCol = bsubCol;
    c22.subsr = c.subsr + asubRow, c22.subsc = c.subsc + bsubCol, c22.subRow = asubRow, c22.subCol = bsubCol;

    /*S1 = B12 - B22, S2 = A11 + A12, S3 = A21 + A22,
      S4 = B21 - B11, S5 = A11 + A22, S6 = B11 + B22,
      S7 = A12 - A22, S8 = B21 + B22, S9 = A11 - A21,
      S10 = B11 + B12*/
    DATATYPE *s1 = new DATATYPE[bsubRow*bsubCol], *s2 = new DATATYPE[asubRow*asubCol], *s3 = new DATATYPE[asubRow*asubCol],
        *s4 = new DATATYPE[bsubRow*bsubCol], *s5 = new DATATYPE[asubRow*asubCol], *s6 = new DATATYPE[bsubRow*bsubCol],
        *s7 = new DATATYPE[asubRow*asubCol], *s8 = new DATATYPE[bsubRow*bsubCol], *s9 = new DATATYPE[asubRow*asubCol],
        *s10 = new DATATYPE[bsubRow*bsubCol];
    SubMat S1(s1, bsubRow, bsubCol), S2(s2, asubRow, asubCol), S3(s3, asubRow, asubCol),
        S4(s4, bsubRow, bsubCol), S5(s5, asubRow, asubCol), S6(s6, bsubRow, bsubCol),
        S7(s7, asubRow, asubCol), S8(s8, bsubRow, bsubCol), S9(s9, asubRow, asubCol),
        S10(s10, bsubRow, bsubCol);
    MatrixMinusAB(b12,b22,S1);
    MatrixAddAB(a11,a12,S2);
    MatrixAddAB(a21,a22,S3);
    MatrixMinusAB(b21,b11,S4);
    MatrixAddAB(a11,a22,S5);
    MatrixAddAB(b11,b22,S6);
    MatrixMinusAB(a12,a22,S7);
    MatrixAddAB(b21,b22,S8);
    MatrixMinusAB(a11,a21,S9);
    MatrixAddAB(b11,b12,S10);
    /*P1 = A11 * S1, P2 = S2 * B22, P3 = S3 * B11, P4 = A22 * S4
      P5 = S5 * S6, P6 = S7 * S8, P7 = S9 * S10 */
    int num = asubRow*bsubCol;
    DATATYPE *p1 = new DATATYPE[num], *p2 = new DATATYPE[num], *p3 = new DATATYPE[num],
        *p4 = new DATATYPE[num], *p5 = new DATATYPE[num], *p6 = new DATATYPE[num],
        *p7 = new DATATYPE[num];
    SubMat P1(p1, asubRow, bsubCol), P2(p2, asubRow, bsubCol), P3(p3, asubRow, bsubCol),
        P4(p4, asubRow, bsubCol), P5(p5, asubRow, bsubCol), P6(p6, asubRow, bsubCol),
        P7(p7, asubRow, bsubCol);
    MatrixMultiplyAB(a11,S1,P1);
    MatrixMultiplyAB(S2,b22,P2);
    MatrixMultiplyAB(S3,b11,P3);
    MatrixMultiplyAB(a22,S4,P4);
    MatrixMultiplyAB(S5,S6,P5);
    MatrixMultiplyAB(S7,S8,P6);
    MatrixMultiplyAB(S9,S10,P7);
    /*C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7*/
    MatrixAddAB(P5, P4, c11);
    MatrixMinusAB(c11, P2, c11);
    MatrixAddAB(c11, P6, c11);
    MatrixAddAB(P1,P2,c12);
    MatrixAddAB(P3,P4,c21);
    MatrixAddAB(P5,P1,c22);
    MatrixMinusAB(c22,P3,c22);
    MatrixMinusAB(c22,P7,c22);
    delete[] s1, delete[] s2, delete[] s3, delete[] s4, delete[] s5, delete[] s6, delete[] s7, delete[] s8, delete[] s9, delete[] s10;
    delete[] p1, delete[] p2, delete[] p3, delete[] p4, delete[] p5, delete[] p6, delete[] p7;
}
// Convenience wrapper around MatrixMultiplyAB for dense, row-major matrices
// stored in flat arrays (e.g. &m[0][0] of a 2-D array).
//   a - arow x acol input matrix
//   b - brow x bcol input matrix; acol must equal brow
//   c - output buffer with room for arow x bcol elements; receives a * b
// NOTE: the parameters are int*. They match SubMat's DATATYPE* only while
// DATATYPE is int. The function name's spelling is kept for existing callers.
void MatrixMultipy(int* a, int arow, int acol, int* b, int brow, int bcol, int* c)
{
    // With mismatched inner dimensions MatrixMultiplyAB could read outside a or b,
    // so reject that up front (checked in debug builds).
    assert(acol == brow && "MatrixMultipy: inner dimensions must match");
    SubMat A(a, arow, acol), B(b, brow, bcol), C(c, arow, bcol);
    MatrixMultiplyAB(A, B, C);
}

使用一个较大的矩阵(16×8 乘以 8×4)进行测试:

// Demo: multiplies a 16x8 matrix a by an 8x4 matrix b with MatrixMultipy and
// prints the 16x4 product c, one tab-separated row per line.
int main()
{
    int a[16][8] = {
        { 1, 1, 1, 1, 2, 66, 456, 4 }, { 1, 2, 2, 3, 4, 7, 9, 10 }, { 1, 2, 2, 3, 4, 4, 9, 10 }, { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 }, { 1, 2, 2, 3, 4, 7, 9, 10 }, { 1, 2, 2, 3, 4, 4, 9, 10 }, { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 }, { 1, 2, 2, 3, 4, 7, 9, 10 }, { 1, 2, 2, 3, 4, 4, 9, 10 }, { 1, 56, 2, 3, 4, 4, 9, 10 },
        { 1, 1, 1, 1, 2, 66, 456, 4 }, { 1, 2, 2, 3, 4, 7, 9, 10 }, { 1, 2, 2, 3, 4, 4, 9, 10 }, { 1, 56, 2, 3, 4, 4, 9, 10 }
    };
    int b[8][4] = {
        { 1, 5, 6, 1 }, { 3, 22, 1, 1 }, { 3, 5, 0, 3 }, { 4, 24, 11, 5 },
        { 3, 5, 67, 3 }, { 7, 5, 67, 3 }, { 9, 5, 67, 3 }, { 3, 5, 27, 3 }
    };
    int c[16][4];
    MatrixMultipy(&a[0][0], 16, 8, &b[0][0], 8, 4, &c[0][0]);
    for (int i = 0; i < 16; i++)
    {
        for (int j = 0; j < 4; j++)
        {
            std::cout << c[i][j] << "\t";
        }
        std::cout << '\n';
    }
    std::cout << std::flush; // make sure the table is visible before any pause
#ifdef _WIN32
    // "pause" is a cmd.exe builtin. It keeps the console window open on Windows
    // and does not exist on other platforms.
    std::system("pause");
#endif
    return 0;
}

测试结果:
(原文此处为程序运行结果截图,图片已丢失。)

使用 MATLAB 验证一下:
(原文此处为两张 MATLAB 验证结果截图,图片已丢失。)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值