矩阵乘积计算(Strassen)

23 篇文章 2 订阅
11 篇文章 1 订阅

矩阵乘积计算(Strassen)

问题描述

​ 已知A,B两个矩阵计算其乘积C?

矩阵乘积数学公式:

​ 假设存在两个矩阵A为m×n矩阵,B为k×l矩阵,若需要计算AB则必须n=k,若需要计算BA必须l=m否则无法进行计算,先假定n=k即B为n×l矩阵则AB的结果为一个m×l的矩阵并且该矩阵每个点的元素的值表示为 Cij 则:

这里写图片描述

这里写图片描述

方法一:直接计算

​ 直接利用多重for循环求出相关矩阵对应的点的值即可

//矩阵的数据结构,随机矩阵,非特殊矩阵
struct array
{
    int **data;                 //数据域
    int row;
    int col;
};

/**
 *  初始化矩阵元素,用随机数填充
 *  只为研究算法因此为进行相关的内存检查
 *  flag用来标记是否生成空矩阵,即元素全部为0的矩阵
 */
void init_array(struct array *ptr,const int row,const int col,int flag)
{
    int i = 0,j = 0;
    ptr->data = (int **)malloc(sizeof(int)*row);                                    //??内存分配
    for (i = 0; i < row; i++)
    {
        *(ptr->data + i) = (int*)malloc(sizeof(int)*col);                           //??内存分配
    }

    ptr->col = col;
    ptr->row = row;
    srand(time(NULL));
    for (i = 0; i < row; i++)
    {
        for (j = 0; j < col; j++)
        {
            if (flag)
            {
                ptr->data[i][j] = rand() % ARRAY_PRCE;
            }
            else
            {
                ptr->data[i][j] = 0;
            }
        }
    }
}

/**
 *  打印矩阵元素
 */
void print_array(const struct array *ptr, const char *msg)
{
    int i, j;
    printf("%s\n", msg);
    for (i = 0; i < ptr->row; i++)
    {
        for (j = 0; j < ptr->col; j++)
        {
            printf("%4d", ptr->data[i][j]);
        }

        printf("\n");
    }
}
/**
 *  销毁内存
 */
void delete_array(struct array *ptr)
{
    int i = 0;
    for (i = 0; i < ptr->row; i++)
    {
        free(*(ptr->data + i));
        *(ptr->data + i) = NULL;
    }

    free(ptr->data);
}

/*******************************************************************************************

*******************************************************************************************/

/**
*   矩阵乘法求解  
*   问题描述:已知两个可以进行相乘的矩阵,求的乘积后的结果
*/

/**
*   方法一:暴力直接求解  
*   利用矩阵乘法规则直接进行求解罗列出每个点的值求的最终的矩阵
*/
struct array mult_array(const struct array *ptr1, const struct array *ptr2)
{
    int i = 0;
    int j = 0;
    int k = 0;
    struct array ptr;
    if (ptr1->col != ptr2->row)                         //检查是否符合可以进行乘积的要求
    {
        return;
    }

    init_array(&ptr, ptr1->row, ptr2->col, 0);
    for (i = 0; i < ptr.row;i ++)
    {
        for (j = 0; j < ptr.col; j++)
        {
            for (k = 0; k < ptr1->col; k++)
            {
                ptr.data[i][j] += ptr1->data[i][k] * ptr2->data[k][j];
            }
        }
    }

    return ptr;
}

执行效果

这里写图片描述

时间复杂度为O( n3 )

方法二:分治算法

​ 将矩阵分解为一个个小矩阵进行计算然后将计算结果合并得到相关的结果。源于矩阵服从分配率和结合律,并不支持交换律。

​ 三个矩阵本身就可以写成下面的格式

这里写图片描述

​ 那么相关的计算可以写成

这里写图片描述

​ 同理A11等一些子矩阵也可以写成相关的子矩阵,就这样将矩阵不断分解为小矩阵进行计算,最后归并为一个矩阵。

​ 时间复杂度为O( n3 )

/**
 *  方法二:利用分治思想进行求解
 *  存在问题无法解决不同类型的矩阵的问题,要求矩阵的行列必须为2的n次方,若不符合要求可以使用
 *  补0来构造相关的矩阵
 */
Matrix* Matrix::merge_calc(const Matrix& x)
{
    if (x.row == 1)             //当前的矩阵为单个的元素
    {
        Matrix *ptr = new Matrix(x.row, x.col);
        ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));
        return ptr;
    }

    //将第一个矩阵分解为四个子矩阵
    Matrix A11(0,0,row/2,col/2,*this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //将第二个矩阵分解为四个子矩阵
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));
    Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));
    Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));
    Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));

    //将C11,C12,C21,C22合并为一个完整的矩阵
    Matrix* ptr = Matrix::merge(C11, C12, C21, C22);

    return ptr;
}

方法三:Strassen算法

​ Strassen算法同样是使用分治的思想解决问题,只不过,不同的是当矩阵的阶很大时就会采取一个递推式进行计算相关递推式为:

                            S1 = B12 - B22
                            S2 = A11 + A12
                            S3 = A21 + A22
                            S4 = B21 - B11
                            S5 = A11 + A22
                            S6 = B11 + B22
                            S7 = A12 - A22
                            S8 = B21 + B22
                            S9 = A11 - A21
                            S10 = B11 + B12 
                            P1 = A11 * S1
                            P2 = S2 * B22
                            P3 = S3 * B11
                            P4 = A22 * S4
                            P5 = S5 * S6
                            P6 = S7 * S8
                            P7 = S9 * S10
                            C11 = P5 + P4 - P2 + P6
                            C12 = P1 + P2
                            C21 = P3 + P4
                            C22 = P5 + P1 - P3 - P7

​ 其中A11,A12,A21,A22和B11,B12,B21,B22分别为两个乘数A和B矩阵的四个子矩阵。C11,C12,C21,C22为最终的结果C矩阵的四个子矩阵。该递推式是被数学家证明过的。

​ 该算法的效率为O( n(log27) ),但是相对来说额外空间的使用也是很多的。

Matrix* Matrix::strassen_calc(const Matrix& x)
{
    if (x.row < 2)
    {
        return this->force_calc(x);
    }

    //将第一个矩阵分解为四个子矩阵
    Matrix A11(0, 0, row / 2, col / 2, *this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //将第二个矩阵分解为四个子矩阵
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix* S1 = B12 - B22;
    Matrix* S2 = A11 + A12;
    Matrix* S3 = A21 + A22;
    Matrix* S4 = B21 - B11;
    Matrix* S5 = A11 + A22;
    Matrix* S6 = B11 + B22;
    Matrix* S7 = A12 - A22;
    Matrix* S8 = B21 + B22;
    Matrix* S9 = A11 - A21;
    Matrix* S10 = B11 + B12;

    Matrix* P1 = B12 - B22;
    Matrix* P2 = B12 - B22;
    Matrix* P3 = B12 - B22;
    Matrix* P4 = B12 - B22;
    Matrix* P5 = B12 - B22;
    Matrix* P6 = B12 - B22;
    Matrix* P7 = B12 - B22;

    P1 = A11.strassen_calc(*S1);
    P2 = S2->strassen_calc(B22);
    P3 = S3->strassen_calc(B11);
    P4 = A22.strassen_calc(*S4);
    P5 = S5->strassen_calc(*S6);
    P6 = S7->strassen_calc(*S8);
    P7 = S9->strassen_calc(*S10);

    Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));
    Matrix *C12 = Matrix::add(P1, P2);
    Matrix *C21 = Matrix::add(P3, P4);
    Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));

    return Matrix::merge(C11,C12,C21,C22);
}

执行效果:

这里写图片描述

完整的代码

//Matrix.h
#pragma once
#ifndef _MATRIX_H_
#define _MATRIX_H_

#include <iostream>
#include <vector>
using std::vector;
#define VISE 5
#define GATE 16                 //用来限定使用哪种算法进行计算

#include <cstdlib>
#include <ctime>

typedef int type;
class Matrix
{
private:
    int row;                            //行
    int col;                            //列
    vector<vector<type>> data;          //数据
public:
    Matrix(int row, int col) :data(row),row(row),col(col)                   //矩阵数据生成利用随机数进行生成
    {
        for (int i = 0; i < row; i++)
        {
            data[i].resize(col);
        }

        srand(time(0));
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < col; j++)
            {
                data[i][j] = rand() % VISE;
            }
        }
    }

    Matrix(int row1, int col1, int row2, int col2, const Matrix& x) :row(row2 - row1), col(col2 - col1),data(row)
    {
        for (int i = 0; i < row; i++)
        {
            data[i].resize(col);
        }

        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < col; j++)
            {
                data[i][j] = x.getElem(col1 + i, row1 + j);
            }
        }
    }

    Matrix(const Matrix& x)
    {
        *this = x;
    }

    //相关算数运算操作
    Matrix* operator+(const Matrix&);
    Matrix* operator-(const Matrix&);
    Matrix* operator*(const Matrix&);
    static Matrix* add(const Matrix*, const Matrix*);                               //+
    static Matrix* sub(const Matrix*, const Matrix*);                               //-
    static Matrix* merge(const Matrix*, const Matrix*,const Matrix*,const Matrix*); //将四个子矩阵合并为一个矩阵
    //获取矩阵的相关元素
    vector<type> operator[](const int);             //取得row
    type getElem(const int,const int) const;        //获取相关节点的数据
    void setElem(const int, const int, type);       //设置节点的数据   

    //计算乘法的算法
    Matrix* force_calc(const Matrix&);              //直接暴力求解
    Matrix* merge_calc(const Matrix&);              //分治求解
    Matrix* strassen_calc(const Matrix&);           //Strassen算法

    void show();                                    //打印矩阵
    bool isSimilar(const Matrix& x);                //行列相同即为同类型矩阵
    void clear(type);                               //设置矩阵中所有的元素为同一个指定的值                                

    ~Matrix();
};

#endif
//_MATRIX_H_
//Matrix.cpp
#include "Matrix.h"

Matrix* Matrix::operator*(const Matrix& x)
{
    if (x.row != this->col)
    {
        return nullptr;
    }

    if (x.row < VISE && x.col < VISE && row < VISE && col < VISE)
    {
        return this->merge_calc(x);
    }

    return this->strassen_calc(x);
}

/**
 *  方法一:暴力直接求解问题
 *  时间复杂度为O(n^3)
 */
Matrix* Matrix::force_calc(const Matrix& x)
{
    if (x.row != this->col)                                             //行列不同无法进行乘法,可以进行补零将相关矩阵填充为可使用的矩阵
    {                                                                   //这里不进行相关的编写
        return nullptr;
    }

    Matrix *ptr = new Matrix(row, x.col);
    ptr->clear(0);
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < x.col; j++)
        {
            for (int k = 0; k < col; k++)
            {
                ptr->setElem(i, j, ptr->getElem(i, j) + getElem(i, k) * x.getElem(k, j));
            }
        }
    }

    return ptr;
}

void Matrix::clear(type cur = 0)
{
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            data[i][j] = cur;
        }
    }
}

/**
 *  方法二:利用分治思想进行求解
 *  存在问题无法解决不同类型的矩阵的问题,要求矩阵的行列必须为2的n次方,若不符合要求可以使用
 *  补0来构造相关的矩阵
 */
Matrix* Matrix::merge_calc(const Matrix& x)
{
    if (x.row == 1)             //当前的矩阵为单个的元素
    {
        Matrix *ptr = new Matrix(x.row, x.col);
        ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));
        return ptr;
    }

    //将第一个矩阵分解为四个子矩阵
    Matrix A11(0,0,row/2,col/2,*this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //将第二个矩阵分解为四个子矩阵
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));
    Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));
    Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));
    Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));

    //将C11,C12,C21,C22合并为一个完整的矩阵
    Matrix* ptr = Matrix::merge(C11, C12, C21, C22);

    return ptr;
}

Matrix* Matrix::strassen_calc(const Matrix& x)
{
    if (x.row < 2)
    {
        return this->force_calc(x);
    }

    //将第一个矩阵分解为四个子矩阵
    Matrix A11(0, 0, row / 2, col / 2, *this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //将第二个矩阵分解为四个子矩阵
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix* S1 = B12 - B22;
    Matrix* S2 = A11 + A12;
    Matrix* S3 = A21 + A22;
    Matrix* S4 = B21 - B11;
    Matrix* S5 = A11 + A22;
    Matrix* S6 = B11 + B22;
    Matrix* S7 = A12 - A22;
    Matrix* S8 = B21 + B22;
    Matrix* S9 = A11 - A21;
    Matrix* S10 = B11 + B12;

    Matrix* P1 = B12 - B22;
    Matrix* P2 = B12 - B22;
    Matrix* P3 = B12 - B22;
    Matrix* P4 = B12 - B22;
    Matrix* P5 = B12 - B22;
    Matrix* P6 = B12 - B22;
    Matrix* P7 = B12 - B22;

    P1 = A11.strassen_calc(*S1);
    P2 = S2->strassen_calc(B22);
    P3 = S3->strassen_calc(B11);
    P4 = A22.strassen_calc(*S4);
    P5 = S5->strassen_calc(*S6);
    P6 = S7->strassen_calc(*S8);
    P7 = S9->strassen_calc(*S10);

    Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));
    Matrix *C12 = Matrix::add(P1, P2);
    Matrix *C21 = Matrix::add(P3, P4);
    Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));

    return Matrix::merge(C11,C12,C21,C22);
}

/**
 *  将四个子矩阵合并为一个完整的矩阵
 *  也可以使用分治思想进行解决,以后可能会添加相关的功能
 */
Matrix* Matrix::merge(const Matrix* p1, const Matrix* p2,const Matrix* p3, const Matrix* p4)
{
    //不符合可以进行合并的条件
    if (!(p1->row == p2->row && p2->col == p4->col && p4->row == p3->row && p1->col == p3->col))
    {
        return nullptr;
    }

    Matrix* ptr = new Matrix(p1->row + p3->row, p2->col + p1->col);
    ptr->clear(0);
    //重新装值
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, p1->getElem(i, j));
        }
    }

    for (int i = 0; i < p2->row; i++)
    {
        for (int j = 0; j < p2->col; j++)
        {
            ptr->setElem(i, j + p1->col, p2->getElem(i, j));
        }
    }

    for (int i = 0; i < p3->row; i++)
    {
        for (int j = 0; j < p3->col; j++)
        {
            ptr->setElem(i + p1->row, j, p3->getElem(i, j));
        }
    }

    for (int i = 0; i < p4->row; i++)
    {
        for (int j = 0; j < p4->col; j++)
        {
            ptr->setElem(p1->row + i, p1->col + j, p4->getElem(i, j));
        }
    }

    return ptr;
}

Matrix* Matrix::sub(const Matrix* p1, const Matrix* p2)
{
    if (!(p1->col == p2->col && p1->row == p2->row))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(p1->row, p1->col);
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, (p1->getElem(i, j) - p2->getElem(i, j)));
        }
    }

    return ptr;
}

Matrix* Matrix::add(const Matrix* p1, const Matrix* p2)
{
    if (!(p1->col == p2->col && p1->row == p2->row))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(p1->row, p1->col);
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, (p1->getElem(i, j) + p2->getElem(i, j)));
        }
    }

    return ptr;
}

Matrix* Matrix::operator+(const Matrix& x)
{
    if (!isSimilar(x))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(x.row, x.col);                             //内存需要释放
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            ptr->setElem(i, j, this->getElem(i, j) + x.getElem(i, j));
        }
    }

    return ptr;
}

Matrix* Matrix::operator-(const Matrix& x)
{
    if (!isSimilar(x))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(x.row, x.col);                             //内存需要释放
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            ptr->setElem(i, j, this->getElem(i, j) - x.getElem(i, j));
        }
    }

    return ptr;
}

vector<type> Matrix::operator[](const int row)
{
    return data[row];
}

type Matrix::getElem(int row, int col)const
{
    return this->data[row][col];
}

void Matrix::setElem(int row, int col, type cur)
{
    this->data[row][col] = cur;
}

void Matrix::show()
{
    for (int i = 0; i < row; i++)
    {
        if (i == 0)
        {
            std::cout << "┏";
        }
        else if (i == row - 1)
        {
            std::cout << "┗";
        }
        else
        {
            std::cout << "┃";
        }

        for (int j = 0; j < col; j++)
        {
            std::cout.width(4);
            std::cout << data[i][j];
        }

        if (i == 0)
        {
            std::cout << "   ┓";
        }
        else if (i == row - 1)
        {
            std::cout << "   ┛";
        }
        else
        {
            std::cout << "   ┃";
        }

        std::cout << std::endl;
    }
}

bool Matrix::isSimilar(const Matrix& x)
{
    return x.row == this->row && this->col == x.col;
}

Matrix::~Matrix()
{
    this->row = 0;
    this->col = 0;
}

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值