//
// main.cpp
// Strassen
//
// Created by Longxiang Lyu on 5/24/16.
// Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//
#include <algorithm>
#include <iostream>
#include <math.h>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
// Prints `matrix` to standard output, one row per line, with every element
// followed by a single space.
void printMatrix(const vector<vector<int>> &matrix)
{
    // Bind rows and elements by const reference: the old `auto row` copied
    // every row vector just to print it.
    for (const auto &row : matrix)
    {
        for (const int elem : row)
            cout << elem << " ";
        cout << '\n';  // no per-row flush; output text is identical to endl
    }
}
// Pads `matrix` in place with zeros so it becomes square, with a side equal to
// the smallest power of two that is >= max(row count, longest row length).
// Existing elements keep their positions; every new cell is 0.
// An empty matrix is left unchanged.
void zeroPadding(vector<vector<int>> &matrix)
{
    // Largest dimension, taking every row into account (not just row 0).
    size_t extent = matrix.size();
    for (const auto &row : matrix)
        extent = max(extent, row.size());
    if (extent == 0)
        return;  // nothing to pad; also avoids touching matrix[0] of an empty matrix

    // Smallest power of two >= extent. The previous sqrt-based formula
    // over-padded enormously (e.g. 100 -> 2048) and padded sizes that
    // already were powers of two.
    size_t sz = 1;
    while (sz < extent)
        sz *= 2;

    matrix.resize(sz);
    for (auto &row : matrix)
        row.resize(sz, 0);  // grows both existing and newly added (empty) rows
}
// Element-wise sum: ret = A + B.
// A and B must have the same shape. `ret` is resized to A's shape, so callers
// need not pre-size it, and every cell is overwritten.
void sum(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &ret)
{
    ret.resize(A.size());
    for (size_t i = 0; i < A.size(); ++i)
    {
        ret[i].resize(A[i].size());
        for (size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] + B[i][j];
    }
}
// Element-wise difference: ret = A - B.
// A and B must have the same shape. `ret` is resized to A's shape, so callers
// need not pre-size it, and every cell is overwritten.
void subtract(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &ret)
{
    ret.resize(A.size());
    for (size_t i = 0; i < A.size(); ++i)
    {
        ret[i].resize(A[i].size());
        for (size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] - B[i][j];
    }
}
// Multiplies two square n x n matrices A and B and stores the n x n product in
// `ret`. Any previous contents of `ret` are discarded.
//
// For even n, both matrices are split into four n/2 x n/2 quadrants and
// combined from Strassen's seven recursive products p1..p7. When n is 0, 1 or
// odd, the matrix cannot be halved, so the schoolbook triple loop is used.
//
// Precondition: A and B are both n x n (not checked).
void strassenHelper(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &ret)
{
    const size_t sz = A.size();
    if (sz <= 1 || sz % 2 != 0)
    {
        // Direct product. This covers:
        //   n == 0 - halving would recurse on size 0 forever;
        //   n == 1 - the recursion's base case;
        //   odd n  - halving would silently drop the last row/column.
        vector<vector<int>> product(sz, vector<int>(sz, 0));
        for (size_t i = 0; i < sz; ++i)
            for (size_t k = 0; k < sz; ++k)
                for (size_t j = 0; j < sz; ++j)
                    product[i][j] += A[i][k] * B[k][j];
        ret.swap(product);
        return;
    }

    const size_t half = sz / 2;

    // Copies the half x half block of M whose top-left corner is (row0, col0).
    auto quadrant = [half](const vector<vector<int>> &M, size_t row0, size_t col0)
    {
        vector<vector<int>> block(half, vector<int>(half));
        for (size_t i = 0; i < half; ++i)
            for (size_t j = 0; j < half; ++j)
                block[i][j] = M[row0 + i][col0 + j];
        return block;
    };

    vector<vector<int>> a11 = quadrant(A, 0, 0);
    vector<vector<int>> a12 = quadrant(A, 0, half);
    vector<vector<int>> a21 = quadrant(A, half, 0);
    vector<vector<int>> a22 = quadrant(A, half, half);
    vector<vector<int>> b11 = quadrant(B, 0, 0);
    vector<vector<int>> b12 = quadrant(B, 0, half);
    vector<vector<int>> b21 = quadrant(B, half, 0);
    vector<vector<int>> b22 = quadrant(B, half, half);

    // Scratch buffers reused for the operand sums/differences of each product.
    vector<vector<int>> result1(half, vector<int>(half, 0));
    vector<vector<int>> result2(half, vector<int>(half, 0));

    // p1 = (a11 + a22)(b11 + b22)
    vector<vector<int>> p1(half, vector<int>(half, 0));
    sum(a11, a22, result1);
    sum(b11, b22, result2);
    strassenHelper(result1, result2, p1);
    // p2 = (a21 + a22) b11
    vector<vector<int>> p2(half, vector<int>(half, 0));
    sum(a21, a22, result1);
    strassenHelper(result1, b11, p2);
    // p3 = a11 (b12 - b22)
    vector<vector<int>> p3(half, vector<int>(half, 0));
    subtract(b12, b22, result2);
    strassenHelper(a11, result2, p3);
    // p4 = a22 (b21 - b11)
    vector<vector<int>> p4(half, vector<int>(half, 0));
    subtract(b21, b11, result2);
    strassenHelper(a22, result2, p4);
    // p5 = (a11 + a12) b22
    vector<vector<int>> p5(half, vector<int>(half, 0));
    sum(a11, a12, result1);
    strassenHelper(result1, b22, p5);
    // p6 = (a21 - a11)(b11 + b12)
    vector<vector<int>> p6(half, vector<int>(half, 0));
    subtract(a21, a11, result1);
    sum(b11, b12, result2);
    strassenHelper(result1, result2, p6);
    // p7 = (a12 - a22)(b21 + b22)
    vector<vector<int>> p7(half, vector<int>(half, 0));
    subtract(a12, a22, result1);
    sum(b21, b22, result2);
    strassenHelper(result1, result2, p7);

    vector<vector<int>> c11(half, vector<int>(half, 0));
    vector<vector<int>> c12(half, vector<int>(half, 0));
    vector<vector<int>> c21(half, vector<int>(half, 0));
    vector<vector<int>> c22(half, vector<int>(half, 0));
    // c12 = p3 + p5, c21 = p2 + p4
    sum(p3, p5, c12);
    sum(p2, p4, c21);
    // c11 = p1 + p4 + p7 - p5
    sum(p1, p4, result1);
    sum(result1, p7, result2);
    subtract(result2, p5, c11);
    // c22 = p1 + p3 + p6 - p2
    sum(p1, p3, result1);
    sum(result1, p6, result2);
    subtract(result2, p2, c22);

    // Assemble the four quadrants. Writing into a fresh matrix and swapping
    // at the end keeps the result correct even if `ret` aliases A or B.
    vector<vector<int>> product(sz, vector<int>(sz, 0));
    for (size_t i = 0; i < half; ++i)
    {
        for (size_t j = 0; j < half; ++j)
        {
            product[i][j] = c11[i][j];
            product[i][j + half] = c12[i][j];
            product[i + half][j] = c21[i][j];
            product[i + half][j + half] = c22[i][j];
        }
    }
    ret.swap(product);
}
// Computes the matrix product A * B with Strassen's algorithm and stores it in
// `ret`, which is overwritten with the m x n result.
// A must be m x k and B must be k x n, with every row of a matrix the same
// length. A and B themselves are not modified: zero padding is applied to
// internal copies.
// Throws runtime_error if either matrix is empty, if either matrix has rows
// of different lengths, or if A's column count differs from B's row count.
void strassen(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &ret)
{
    if (A.empty() || B.empty())
        throw runtime_error("empty matrices");
    if (A[0].size() != B.size())
        throw runtime_error("A's col not equal B's row");

    // True if every row of M has the same length as its first row.
    auto isRectangular = [](const vector<vector<int>> &M)
    {
        for (const auto &row : M)
            if (row.size() != M[0].size())
                return false;
        return true;
    };
    if (!isRectangular(A) || !isRectangular(B))
        throw runtime_error("matrix rows differ in length");

    const size_t rows = A.size();     // m
    const size_t cols = B[0].size();  // n

    vector<vector<int>> paddedA(A);
    vector<vector<int>> paddedB(B);
    zeroPadding(paddedA);
    zeroPadding(paddedB);

    // zeroPadding sizes each matrix from its own dimensions, so an m x k A and
    // a k x n B can come out with different sides. strassenHelper needs both
    // operands to be the same size, so grow both to the larger side.
    const size_t side = max(paddedA.size(), paddedB.size());
    auto growTo = [side](vector<vector<int>> &M)
    {
        M.resize(side);
        for (auto &row : M)
            row.resize(side, 0);
    };
    growTo(paddedA);
    growTo(paddedB);

    // Pre-sized so the helper can write into it even at its 1 x 1 base case.
    vector<vector<int>> product(side, vector<int>(side, 0));
    strassenHelper(paddedA, paddedB, product);

    // Strip the padding: the true product is m x n.
    product.resize(rows);
    for (auto &row : product)
        row.resize(cols);
    ret.swap(product);
}
// Demo: multiplies two 3 x 3 matrices with strassen() and prints the result.
// Returns 1 and reports the error on stderr if strassen() throws runtime_error.
int main(int, const char *[]) {
    vector<vector<int>> A{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    vector<vector<int>> B{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};
    vector<vector<int>> ret;  // strassen() overwrites ret, so no pre-sizing is needed
    try
    {
        strassen(A, B, ret);
    }
    catch (const runtime_error &e)
    {
        cerr << "strassen failed: " << e.what() << '\n';
        return 1;
    }
    printMatrix(ret);
    return 0;
}
// Reference:
// https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/