Strassen矩阵乘法

#include <iostream.h>

const int N = 8;    //常量N用来定义矩阵的大小

template <typename T>
void STRASSEN(int n, T A[][N], T B[][N], T C[][N]);

template <typename T>
void input(int n, T p[][N]);

template <typename T>
void output(int n, T C[][N]);    //函数声明部分

void main()
{
    int A[N][N], B[N][N], C[N][N];    //定义三个矩阵A,B,C
   
    for (int i = 0; i < 8; i++)
    {
        for (int j = 0; j < 8; j++)
        {
            A[i][j] = i * j;
            B[i][j] = i * j;
        }
    }

    STRASSEN(N,A,B,C);    //调用STRASSEN函数计算
   
    output(N,C);    //输出计算结果
}

template <typename T>
void input(int n, T p[][N])    //矩阵输入函数
{
    int i, j;
   
    for (i = 0; i < n; i++)
    {
        cout << "请输入第" << i+1 << "行" << endl;
        for (j = 0; j < n; j++)
        {
            cin >> p[i][j];
        }
    }
}

template <typename T>
void output(int n, T C[][N])    //矩阵输出函数
{
    int i, j;
    cout << "输出矩阵:" << endl;
    for (i = 0; i < n; i++)
    {
        for ( j = 0; j < n; j++)
            cout << C[i][j] << ' ';
        cout << endl;
    }
}

template <typename T>
void MATRIX_MULTIPLY(T A[][N], T B[][N], T C[][N])    //按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
{
    int i, j, t;
    for (i = 0; i < 2; i++)    //计算A*B-->C
    {
        for (j = 0; j < 2; j++)
        {
            C[i][j] = 0;    //计算完一个C[i][j],C[i][j]应重新赋值为零
            for (t = 0; t < 2; t++)
            {
                C[i][j] = C[i][j] + A[i][t] * B[t][j];
            }
        }
    }
}

template <typename T>
void MATRIX_ADD(int n, T X[][N], T Y[][N], T Z[][N])    //矩阵加法函数X+Y—>Z
{
    int i, j;
    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            Z[i][j] = X[i][j] + Y[i][j];
        }
    }
}

template <typename T>
void MATRIX_SUB(int n, T X[][N], T Y[][N], T Z[][N])    //矩阵减法函数X-Y—>Z
{
    int i, j;
    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            Z[i][j] = X[i][j] - Y[i][j];
        }
    }
}

//    fullfill C = A * B
template <typename T>
void STRASSEN(int n, T A[][N], T B[][N], T C[][N])    //STRASSEN函数(递归)
{
    T A11[N][N], A12[N][N], A21[N][N], A22[N][N];
    T B11[N][N], B12[N][N], B21[N][N], B22[N][N];
    T C11[N][N], C12[N][N], C21[N][N], C22[N][N];
    T M1[N][N], M2[N][N], M3[N][N], M4[N][N], M5[N][N], M6[N][N], M7[N][N];
    T AA[N][N], BB[N][N], MM1[N][N], MM2[N][N];
   
    int i, j;    //,x;
   
    if (n == 2)
    {
        MATRIX_MULTIPLY(A, B, C);    //按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
    }
    else
    {
        for (i = 0; i < n / 2; i++)
        {
            for ( j = 0; j < n / 2; j++)
            {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j + n / 2];
                A21[i][j] = A[i + n / 2][j];
                A22[i][j] = A[i + n / 2][j + n / 2];
                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j + n / 2];
                B21[i][j] = B[i + n / 2][j];
                B22[i][j] = B[i + n / 2][j + n / 2];
                //将矩阵A和B式分为四块
            }
        }
       
        MATRIX_SUB (n / 2, B12, B22, BB);
        STRASSEN (n/ 2, A11, BB, M1);    //M1=A11(B12-B22)
       
        MATRIX_ADD (n / 2, A11, A12, AA);
        STRASSEN (n / 2, AA, B22, M2);    //M2=(A11+A12)B22
       
        MATRIX_ADD (n / 2, A21, A22, AA);
        STRASSEN (n / 2, AA, B11, M3);    //M3=(A21+A22)B11
       
        MATRIX_SUB (n / 2, B21, B11, BB);
        STRASSEN (n / 2, A22, BB, M4);    //M4=A22(B21-B11)
       
        MATRIX_ADD (n / 2, A11, A22, AA);
        MATRIX_ADD (n / 2, B11, B22, BB);
        STRASSEN (n / 2, AA, BB, M5);    //M5=(A11+A22)(B11+B22)
       
        MATRIX_SUB (n / 2, A12, A22, AA);
        MATRIX_SUB (n / 2, B21, B22, BB);
        STRASSEN (n / 2, AA, BB, M6);    //M6=(A12-A22)(B21+B22)
       
        MATRIX_SUB (n / 2, A11, A21, AA);
        MATRIX_SUB (n / 2, B11, B12, BB);
        STRASSEN (n / 2, AA, BB, M7);    //M7=(A11-A21)(B11+B12)
        //计算M1,M2,M3,M4,M5,M6,M7(递归部分)
       
        MATRIX_ADD (n / 2, M5, M4, MM1);
        MATRIX_SUB (n / 2, M6, M2, MM2);
        MATRIX_ADD (n / 2, MM1, MM2, C11);    //C11=M5+M4-M2+M6
       
        MATRIX_ADD (n / 2, M1, M2, C12);    //C12=M1+M2
       
        MATRIX_ADD (n / 2, M3, M4, C21);    //C21=M3+M4
       
        MATRIX_ADD (n / 2, M5, M1, MM1);
        MATRIX_ADD (n / 2, M3, M7, MM2);
        MATRIX_SUB (n / 2, MM1, MM2, C22);    //C22=M5+M1-M3-M7
       
        for (i = 0; i < n / 2; i++)
        {
            for (j = 0 ; j < n / 2; j++)
            {
                C[i][j] = C11[i][j];
                C[i][j + n / 2] = C12[i][j];
                C[i + n / 2][j] = C21[i][j];
                C[i + n / 2][j + n / 2] = C22[i][j];
            }    //计算结果送回C[n][n]
        }
    }
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值