《算法导论》:矩阵乘法的Strassen算法代码实现

一、原文及伪代码

第四章-矩阵乘法的Strassen算法

SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)
1 n = A.rows                                   //A的行数
2 let C be a new n*n matrix                    //让C变成新的n*n矩阵
3 if n == 1
4     c11 = a11 * b11
5 else partition A,B,and C as in equations     //将三个矩阵各自分成4个部分
      //分别求出四个元素
6     C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11) 
         + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)
7     C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12) 
         + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)
8     C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11) 
         + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)
9     C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12) 
         + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)
10 return C
Strassen()
let C be a new n*n matrix
if A.row == 1:
    C = A * B
else partition A,B,and C //步骤1:将四个矩阵各自分为四部分
    //步骤2:计算10个S
    S1=B12-B22
    S2=A11-A12
    S3=A21+A22
    S4=B21-B11
    S5=A11+A22
    S6=B11+B22
    S7=A12-A22
    S8=B21+B22
    S9=A11-A21
    S10=B11+B12
    //步骤3:递归计算7个矩阵积
    P1=Strassen(A11,S1)
    P2=Strassen(A11,B22)
    P3=Strassen(S3,B11)
    P4=Strassen(A22,S4)
    P5=Strassen(S5,S6)
    P6=Strassen(S7,S8)
    P7=Strassen(S9,S10)
    //步骤4:不同Pi的加减运算
    C11=P5+P4-P2+P6
    C12=P1+P2
    C21=P3+P4
    C22=P5+P1-P3-P7
    return C

二、C++代码

#include <iostream>
#include <Windows.h>
using namespace std;
template<typename T>
class Strassen_class {
public:
    void ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);
    void SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);
    void MUL(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize);//朴素算法实现
    void FillMatrix(T** MatrixA, T** MatrixB, int length);//A,B矩阵赋值
    void PrintMatrix(T** MatrixA, int MatrixSize);//打印矩阵
    void Strassen(int N, T** MatrixA, T** MatrixB, T** MatrixC);//Strassen算法实现
};
//矩阵相加
template<typename T>
void Strassen_class<T>::ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
    {
        for (int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
        }
    }
}
//矩阵相减
template<typename T>
void Strassen_class<T>::SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
    {
        for (int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
        }
    }
}
//普通的矩阵乘法
template<typename T>
void Strassen_class<T>::MUL(T** MatrixA, T** MatrixB, T** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
    {
        for (int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] = 0;
            for (int k = 0; k < MatrixSize; k++)
            {
                MatrixResult[i][j] = MatrixResult[i][j] + MatrixA[i][k] * MatrixB[k][j];
            }
        }
    }
}
//A、B矩阵赋值
template<typename T>
void Strassen_class<T>::FillMatrix(T** MatrixA, T** MatrixB, int length)
{
    for (int row = 0; row < length; row++)
    {
        for (int column = 0; column < length; column++)
        {
            //给矩阵里赋值0到4的随机数
            MatrixB[row][column] = (MatrixA[row][column] = rand() % 5);
        }
    }
}
//打印矩阵
template<typename T>
void Strassen_class<T>::PrintMatrix(T** MatrixA, int MatrixSize)
{
    cout << endl;
    for (int row = 0; row < MatrixSize; row++)
    {
        for (int column = 0; column < MatrixSize; column++)
        {
            cout << MatrixA[row][column] << "\t";
            if ((column + 1) % ((MatrixSize)) == 0)
                cout << endl;
        }
    }
    cout << endl;
}

//Strassen算法
template<typename T>
void Strassen_class<T>::Strassen(int N, T * *MatrixA, T * *MatrixB, T * *MatrixC)
{

    int HalfSize = N / 2;
    int newSize = N / 2;
    //当不能分成4个4*4的数组时,我们就采用正常的办法
    if (N <= 64)    
    {
        MUL(MatrixA, MatrixB, MatrixC, N);
    }
    else
    {
        //创建多个二维数组
        T** A11; T** A12; T** A21; T** A22;
        T** B11; T** B12; T** B21; T** B22;
        T** C11; T** C12; T** C21; T** C22;
        T** M1; T** M2; T** M3; T** M4;
        T** M5; T** M6; T** M7;
        T** AResult; T** BResult;
        //创建一个一维数组的指针,用于寻找首地址
        A11 = new T * [newSize];
        A12 = new T * [newSize];
        A21 = new T * [newSize];
        A22 = new T * [newSize];

        B11 = new T * [newSize];
        B12 = new T * [newSize];
        B21 = new T * [newSize];
        B22 = new T * [newSize];

        C11 = new T * [newSize];
        C12 = new T * [newSize];
        C21 = new T * [newSize];
        C22 = new T * [newSize];

        M1 = new T * [newSize];
        M2 = new T * [newSize];
        M3 = new T * [newSize];
        M4 = new T * [newSize];
        M5 = new T * [newSize];
        M6 = new T * [newSize];
        M7 = new T * [newSize];

        AResult = new T * [newSize];
        BResult = new T * [newSize];

        int newLength = newSize;    //N/2长度

        //在上面一维数组的基础上,分别在每一行再创建一个一维数组的指针,从而实现一个二维数组
        for (int i = 0; i < newSize; i++)
        {
            A11[i] = new T[newLength];
            A12[i] = new T[newLength];
            A21[i] = new T[newLength];
            A22[i] = new T[newLength];

            B11[i] = new T[newLength];
            B12[i] = new T[newLength];
            B21[i] = new T[newLength];
            B22[i] = new T[newLength];

            C11[i] = new T[newLength];
            C12[i] = new T[newLength];
            C21[i] = new T[newLength];
            C22[i] = new T[newLength];

            M1[i] = new T[newLength];
            M2[i] = new T[newLength];
            M3[i] = new T[newLength];
            M4[i] = new T[newLength];
            M5[i] = new T[newLength];
            M6[i] = new T[newLength];
            M7[i] = new T[newLength];

            AResult[i] = new T[newLength];
            BResult[i] = new T[newLength];
        }
        //将输入的数组四等分成N/2*N/2的数组,将A和B中的数组各自赋值给自己的四个分支数组
        for (int i = 0; i < N / 2; i++)
        {
            for (int j = 0; j < N / 2; j++)
            {
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + N / 2];
                A21[i][j] = MatrixA[i + N / 2][j];
                A22[i][j] = MatrixA[i + N / 2][j + N / 2];

                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + N / 2];
                B21[i][j] = MatrixB[i + N / 2][j];
                B22[i][j] = MatrixB[i + N / 2][j + N / 2];
            }
        }

        //计算7个矩阵
        //M1=A11(B12-B22)  
        SUB(B12, B22, BResult, HalfSize);     
        Strassen(HalfSize, A11, BResult, M1);

        //M2=(A11+A12)B22 
        ADD(A11, A12, AResult, HalfSize);    
        Strassen(HalfSize, AResult, B22, M2);
        
        //M3=(A21+A22)B11  
        ADD(A21, A22, AResult, HalfSize); 
        Strassen(HalfSize, AResult, B11, M3);

        //M4=A22(B21-B11)    
        SUB(B21, B11, BResult, HalfSize); 
        Strassen(HalfSize, A22, BResult, M4);

        //M5=(A11+A22)(B11+B22)
        ADD(A11, A22, AResult, HalfSize);
        ADD(B11, B22, BResult, HalfSize);    
        Strassen(HalfSize, AResult, BResult, M5); 
       
        //M6=(A12-A22)(B21+B22) 
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);     
        Strassen(HalfSize, AResult, BResult, M6);

        //M7=(A11-A21)(B11+B12)
        SUB(A11, A21, AResult, HalfSize);
        ADD(B11, B12, BResult, HalfSize);    
        Strassen(HalfSize, AResult, BResult, M6);    
 

        //C11 = M5 + M4 - M2 + M6;
        ADD(M5, M4, AResult, HalfSize);
        SUB(M6, M2, BResult, HalfSize);
        ADD(AResult, BResult, C11, HalfSize);

        //C12 = M1 + M1;
        ADD(M1, M2, C12, HalfSize);

        //C21 = M3 + M4;
        ADD(M3, M4, C21, HalfSize);

        //C22 = M5 + M1 - M3 - M7;
        ADD(M5, M1, AResult, HalfSize);
        ADD(M7, M3, BResult, HalfSize);
        SUB(AResult, BResult, C22, HalfSize);

        //组合小矩阵到一个大矩阵
        for (int i = 0; i < N / 2; i++)
        {
            for (int j = 0; j < N / 2; j++)
            {
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + N / 2] = C12[i][j];
                MatrixC[i + N / 2][j] = C21[i][j];
                MatrixC[i + N / 2][j + N / 2] = C22[i][j];
            }
        }

        // 释放矩阵内存空间
        for (int i = 0; i < newLength; i++)
        {
            delete[] A11[i]; delete[] A12[i]; delete[] A21[i];delete[] A22[i];
            delete[] B11[i]; delete[] B12[i]; delete[] B21[i];delete[] B22[i];
            delete[] C11[i]; delete[] C12[i]; delete[] C21[i];delete[] C22[i];
            delete[] M1[i]; delete[] M2[i]; delete[] M3[i]; delete[] M4[i];
            delete[] M5[i]; delete[] M6[i]; delete[] M7[i];
            delete[] AResult[i]; delete[] BResult[i];
         }
        delete[] A11; delete[] A12; delete[] A21; delete[] A22;
        delete[] B11; delete[] B12; delete[] B21; delete[] B22;
        delete[] C11; delete[] C12; delete[] C21; delete[] C22;
        delete[] M1; delete[] M2; delete[] M3; delete[] M4; 
        delete[] M5;delete[] M6; delete[] M7;
        delete[] AResult;delete[] BResult;
    }
}

int main()
{
    Strassen_class<int> stra;//定义Strassen_class类对象
    int MatrixSize = 0;

    int** MatrixA;    //存放矩阵A
    int** MatrixB;    //存放矩阵B
    int** MatrixC;    //存放结果矩阵
    cout << "\n请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
    cin >> MatrixSize;

    int N = MatrixSize;//for readiblity.

    //申请内存
    MatrixA = new int* [MatrixSize];
    MatrixB = new int* [MatrixSize];
    MatrixC = new int* [MatrixSize];
    //申请空间
    for (int i = 0; i < MatrixSize; i++)
    {
        MatrixA[i] = new int[MatrixSize];
        MatrixB[i] = new int[MatrixSize];
        MatrixC[i] = new int[MatrixSize];
    }
    stra.FillMatrix(MatrixA, MatrixB, MatrixSize);  //矩阵赋值
    stra.Strassen(N, MatrixA, MatrixB, MatrixC); //strassen矩阵相乘算法
    cout << "\n矩阵运算结果... \n";
    stra.PrintMatrix(MatrixC, MatrixSize);
    return 0;
}
  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

KeepCoding♪Toby♪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值