// Strassen's algorithm for matrix multiplication.

//
//  main.cpp
//  Strassen
//
//  Created by Longxiang Lyu on 5/24/16.
//  Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//

#include <math.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

/**
 * Print a matrix to standard output, one row per line, with every element
 * followed by a single space.
 *
 * @param matrix  the matrix to print; rows may have different lengths.
 */
void printMatrix(const vector<vector<int>> &matrix)
{
    // Bind each row by reference: `auto row` copied the whole row vector.
    for (const auto &row : matrix)
    {
        for (const int elem : row)
            cout << elem << ' ';
        // '\n' rather than endl: no forced flush after every row.
        cout << '\n';
    }
}

/**
 * Zero-pad a matrix in place to the smallest square whose side is a power
 * of two and which holds every row and column of the input.
 *
 * Existing entries keep their positions; all new cells are 0. An empty
 * matrix is left unchanged.
 *
 * @param matrix  matrix to pad; rows may have different lengths.
 */
void zeroPadding(vector<vector<int>> &matrix)
{
    // The largest dimension, counting every row's width (not just row 0's).
    size_t dim = matrix.size();
    for (const auto &row : matrix)
        dim = max(dim, row.size());
    if (dim == 0)
        return;  // the original indexed matrix[0] here, which is UB when empty

    // Smallest power of two >= dim. The previous pow(2, floor(sqrt(dim)) + 1)
    // formula grew exponentially too large (e.g. 4 -> 8, 100 -> 2048).
    size_t sz = 1;
    while (sz < dim)
        sz *= 2;

    matrix.resize(sz);
    for (auto &row : matrix)
        row.resize(sz, 0);  // new rows start empty, so this zero-fills them too
}

/**
 * Element-wise matrix addition: ret[i][j] = A[i][j] + B[i][j].
 *
 * Visits every row of A and every column of that row. B and ret must be at
 * least as large as A in both dimensions; ret must already be sized (it is
 * written into, not resized). ret may alias A or B.
 *
 * @param A    left operand (read only).
 * @param B    right operand (read only).
 * @param ret  destination for the sums.
 */
void sum(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &ret)
{
    // size_t indices: no signed/unsigned comparison. The column bound is each
    // row's width; the old code reused the row count, which was wrong for
    // non-square inputs.
    for (size_t i = 0; i < A.size(); ++i)
        for (size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] + B[i][j];
}

/**
 * Element-wise matrix subtraction: ret[i][j] = A[i][j] - B[i][j].
 *
 * Visits every row of A and every column of that row. B and ret must be at
 * least as large as A in both dimensions; ret must already be sized (it is
 * written into, not resized). ret may alias A or B.
 *
 * @param A    minuend (read only).
 * @param B    subtrahend (read only).
 * @param ret  destination for the differences.
 */
void subtract(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &ret)
{
    // size_t indices: no signed/unsigned comparison. The column bound is each
    // row's width; the old code reused the row count, which was wrong for
    // non-square inputs.
    for (size_t i = 0; i < A.size(); ++i)
        for (size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] - B[i][j];
}



/// Dimension at or below which strassenHelper uses the schoolbook product
/// instead of recursing. Recursing all the way down to 1x1 allocates seven
/// sub-products per level and is typically much slower than a plain triple
/// loop on small blocks. The best cut-over is machine-dependent, so measure
/// before tuning this value.
static const size_t kStrassenLeafSize = 32;

/// Schoolbook O(n^3) product of two sz x sz matrices; ret is resized to sz x sz.
static void naiveSquareMultiply(const vector<vector<int>> &A, const vector<vector<int>> &B,
                                vector<vector<int>> &ret)
{
    const size_t sz = A.size();
    ret.assign(sz, vector<int>(sz, 0));
    for (size_t i = 0; i < sz; ++i)
        for (size_t k = 0; k < sz; ++k)
        {
            const int a = A[i][k];
            // i-k-j loop order walks B and ret row by row.
            for (size_t j = 0; j < sz; ++j)
                ret[i][j] += a * B[k][j];
        }
}

/// Return a copy of the half x half block of src whose top-left corner is (row0, col0).
static vector<vector<int>> copyQuadrant(const vector<vector<int>> &src, size_t row0, size_t col0,
                                        size_t half)
{
    vector<vector<int>> block(half, vector<int>(half));
    for (size_t i = 0; i < half; ++i)
        for (size_t j = 0; j < half; ++j)
            block[i][j] = src[row0 + i][col0 + j];
    return block;
}

/// Write block into dst with its top-left corner at (row0, col0).
static void storeQuadrant(const vector<vector<int>> &block, vector<vector<int>> &dst, size_t row0,
                          size_t col0)
{
    for (size_t i = 0; i < block.size(); ++i)
        for (size_t j = 0; j < block[i].size(); ++j)
            dst[row0 + i][col0 + j] = block[i][j];
}

/**
 * Multiply two square matrices of the same size using Strassen's algorithm.
 *
 * @param A    left operand, sz x sz (only read).
 * @param B    right operand, sz x sz (only read).
 * @param ret  receives A * B. It is overwritten and resized to sz x sz, so it
 *             may be passed in empty.
 * @throws runtime_error if A and B have different row counts.
 *
 * Sizes at or below kStrassenLeafSize use the schoolbook product. So do odd
 * sizes, which cannot be split into equal quadrants. As a result every square
 * size, including 0, is computed exactly; the previous version silently dropped
 * the last row/column of odd sizes and recursed forever on size 0.
 *
 * Note: Strassen's intermediate sums such as A11 + A22 can overflow int in
 * cases where the schoolbook product would not.
 */
void strassenHelper(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &ret)
{
    const size_t sz = A.size();
    if (B.size() != sz)
        throw runtime_error("strassenHelper: operand sizes differ");

    if (sz <= kStrassenLeafSize || sz % 2 != 0)
    {
        naiveSquareMultiply(A, B, ret);
        return;
    }

    const size_t half = sz / 2;
    vector<vector<int>> a11 = copyQuadrant(A, 0, 0, half);
    vector<vector<int>> a12 = copyQuadrant(A, 0, half, half);
    vector<vector<int>> a21 = copyQuadrant(A, half, 0, half);
    vector<vector<int>> a22 = copyQuadrant(A, half, half, half);
    vector<vector<int>> b11 = copyQuadrant(B, 0, 0, half);
    vector<vector<int>> b12 = copyQuadrant(B, 0, half, half);
    vector<vector<int>> b21 = copyQuadrant(B, half, 0, half);
    vector<vector<int>> b22 = copyQuadrant(B, half, half, half);

    // Scratch operands, reused for every sub-product and for the combine step.
    vector<vector<int>> lhs(half, vector<int>(half)), rhs(half, vector<int>(half));
    // The recursive calls size these themselves.
    vector<vector<int>> p1, p2, p3, p4, p5, p6, p7;

    // P1 = (A11 + A22)(B11 + B22)
    sum(a11, a22, lhs);
    sum(b11, b22, rhs);
    strassenHelper(lhs, rhs, p1);
    // P2 = (A21 + A22) B11
    sum(a21, a22, lhs);
    strassenHelper(lhs, b11, p2);
    // P3 = A11 (B12 - B22)
    subtract(b12, b22, rhs);
    strassenHelper(a11, rhs, p3);
    // P4 = A22 (B21 - B11)
    subtract(b21, b11, rhs);
    strassenHelper(a22, rhs, p4);
    // P5 = (A11 + A12) B22
    sum(a11, a12, lhs);
    strassenHelper(lhs, b22, p5);
    // P6 = (A21 - A11)(B11 + B12)
    subtract(a21, a11, lhs);
    sum(b11, b12, rhs);
    strassenHelper(lhs, rhs, p6);
    // P7 = (A12 - A22)(B21 + B22)
    subtract(a12, a22, lhs);
    sum(b21, b22, rhs);
    strassenHelper(lhs, rhs, p7);

    // Quadrants of A and B were copied above, so resetting ret is safe now.
    ret.assign(sz, vector<int>(sz, 0));
    vector<vector<int>> block(half, vector<int>(half));

    // C11 = P1 + P4 - P5 + P7
    sum(p1, p4, lhs);
    sum(lhs, p7, rhs);
    subtract(rhs, p5, block);
    storeQuadrant(block, ret, 0, 0);
    // C12 = P3 + P5
    sum(p3, p5, block);
    storeQuadrant(block, ret, 0, half);
    // C21 = P2 + P4
    sum(p2, p4, block);
    storeQuadrant(block, ret, half, 0);
    // C22 = P1 + P3 + P6 - P2
    sum(p1, p3, lhs);
    sum(lhs, p6, rhs);
    subtract(rhs, p2, block);
    storeQuadrant(block, ret, half, half);
}




/// Return a side x side copy of src with every cell outside src set to 0.
static vector<vector<int>> zeroPaddedCopy(const vector<vector<int>> &src, size_t side)
{
    vector<vector<int>> out(side, vector<int>(side, 0));
    for (size_t i = 0; i < src.size(); ++i)
        for (size_t j = 0; j < src[i].size(); ++j)
            out[i][j] = src[i][j];
    return out;
}

/**
 * Compute ret = A * B using Strassen's algorithm.
 *
 * A is m x k and B is k x n. Both are copied into zero-padded square
 * matrices sharing one power-of-two side >= max(m, k, n), multiplied by
 * strassenHelper, and the product is cropped back to m x n.
 *
 * @param A    left operand, m x k. Not modified (the previous version padded it
 *             in place).
 * @param B    right operand, k x n. Not modified.
 * @param ret  receives the m x n product; previous contents are discarded.
 * @throws runtime_error if either matrix is empty, A's column count differs
 *         from B's row count, or either matrix has rows of unequal length.
 */
void strassen(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &ret)
{
    if (A.empty() || B.empty())
        throw runtime_error("empty matrices");
    if (A[0].size() != B.size())
        throw runtime_error("A's col not equal B's row");

    const size_t rows = A.size();     // m
    const size_t inner = B.size();    // k
    const size_t cols = B[0].size();  // n
    for (const auto &row : A)
        if (row.size() != inner)
            throw runtime_error("A is not rectangular");
    for (const auto &row : B)
        if (row.size() != cols)
            throw runtime_error("B is not rectangular");

    // The recursion needs both operands to have the same power-of-two side.
    // Padding each operand on its own would give them different sides
    // whenever m != n, and the product would then be wrong.
    size_t side = 1;
    while (side < rows || side < inner || side < cols)
        side *= 2;

    vector<vector<int>> paddedA = zeroPaddedCopy(A, side);
    vector<vector<int>> paddedB = zeroPaddedCopy(B, side);
    // Pre-sized so the output is valid even for a 1x1 problem.
    vector<vector<int>> product(side, vector<int>(side, 0));
    strassenHelper(paddedA, paddedB, product);

    // Crop the padding away: the true product is rows x cols.
    ret.assign(rows, vector<int>(cols));
    for (size_t i = 0; i < rows; ++i)
        for (size_t j = 0; j < cols; ++j)
            ret[i][j] = product[i][j];
}

// Demo: multiply two 3x3 matrices with strassen() and print the product.
int main(int /*argc*/, const char * /*argv*/[])
{
    vector<vector<int>> lhs{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    vector<vector<int>> rhs{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};

    // strassen() overwrites its output argument, so it can start out empty.
    vector<vector<int>> product;
    strassen(lhs, rhs, product);
    printMatrix(product);

    return 0;
}

// Reference:
// https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/

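/*
 * Residue from the blog page this code was copied from: like / favorite /
 * comment counters, a "were the related recommendations helpful?" rating
 * widget, and a "red packet" tipping/payment dialog. It is not part of the
 * program; the original text is kept verbatim below.

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
 */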