#include <cstddef>
#include <iostream>
#include <vector>
using namespace std;
// Element-wise sum of two MatrixSize x MatrixSize matrices:
//   MatrixResult[r][c] = MatrixA[r][c] + MatrixB[r][c].
// Each output cell is written only after its two inputs are read, so
// MatrixResult may be the very same matrix as MatrixA or MatrixB.
void ADD(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for(int row = 0; row < MatrixSize; ++row){
        const int *lhs = MatrixA[row];
        const int *rhs = MatrixB[row];
        int *dst = MatrixResult[row];
        for(int col = 0; col < MatrixSize; ++col)
            dst[col] = lhs[col] + rhs[col];
    }
}
// Schoolbook O(n^3) product of two MatrixSize x MatrixSize matrices:
//   MatrixResult = MatrixA * MatrixB.
// Strassen falls back to this routine for small sizes. Every output cell is
// zeroed and then accumulated in place, so MatrixResult must not share
// storage with either input.
void MUL(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for(int r = 0; r < MatrixSize; ++r){
        for(int c = 0; c < MatrixSize; ++c){
            int &cell = MatrixResult[r][c];
            cell = 0;
            for(int k = 0; k < MatrixSize; ++k)
                cell += MatrixA[r][k] * MatrixB[k][c];
        }
    }
}
// Element-wise difference of two MatrixSize x MatrixSize matrices:
//   MatrixResult[r][c] = MatrixA[r][c] - MatrixB[r][c].
// Each output cell is written only after its two inputs are read, so
// MatrixResult may be the very same matrix as MatrixA or MatrixB.
void SUB(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for(int row = 0; row < MatrixSize; ++row){
        const int *minuend = MatrixA[row];
        const int *subtrahend = MatrixB[row];
        int *dst = MatrixResult[row];
        for(int col = 0; col < MatrixSize; ++col)
            dst[col] = minuend[col] - subtrahend[col];
    }
}
// Multiplies two N x N int matrices with Strassen's algorithm:
//   MatrixC = MatrixA * MatrixB.
// For N <= 64 the naive MUL is used directly. Larger inputs are split into
// four quadrants each, and the product is assembled from seven recursive
// half-size multiplications (M1..M7) instead of the eight a plain block
// product would need.
//
// Odd N is handled by conceptually zero-padding both operands to
// 2 * ceil(N/2): the zero rows/columns contribute nothing to the top-left
// N x N block, which is the only part written back to MatrixC. (Splitting at
// N/2 with truncation would drop the last row and column of an odd N.)
//
// MatrixC must not share storage with MatrixA or MatrixB (MUL accumulates
// into its output in place).
void Strassen(int N, int **MatrixA, int **MatrixB, int **MatrixC){
    if(N <= 64){
        MUL(MatrixA, MatrixB, MatrixC, N);
        return;
    }

    // Round up so an odd N keeps its last row/column (it lands in the padding).
    const int HalfSize = (N + 1) / 2;
    const std::size_t half = static_cast<std::size_t>(HalfSize);
    const std::size_t cells = half * half;

    // All 21 HalfSize x HalfSize temporaries live in one zero-initialised
    // buffer owned by std::vector, so nothing leaks even if an allocation or
    // a recursive call throws. `rows` provides the int** views the helper
    // functions take; temporary t owns rows [t*HalfSize, (t+1)*HalfSize).
    const std::size_t kTemps = 21;
    std::vector<int> storage(kTemps * cells, 0);
    std::vector<int *> rows(kTemps * half);
    for(std::size_t r = 0; r < rows.size(); ++r){
        rows[r] = storage.data() + r * half;
    }
    auto temp = [&rows, half](std::size_t t){ return rows.data() + t * half; };

    int **A11 = temp(0);
    int **A12 = temp(1);
    int **A21 = temp(2);
    int **A22 = temp(3);
    int **B11 = temp(4);
    int **B12 = temp(5);
    int **B21 = temp(6);
    int **B22 = temp(7);
    int **C11 = temp(8);
    int **C12 = temp(9);
    int **C21 = temp(10);
    int **C22 = temp(11);
    int **M1 = temp(12);
    int **M2 = temp(13);
    int **M3 = temp(14);
    int **M4 = temp(15);
    int **M5 = temp(16);
    int **M6 = temp(17);
    int **M7 = temp(18);
    int **AResult = temp(19);
    int **BResult = temp(20);

    // Reads M[r][c], yielding 0 for coordinates in the padding beyond N x N.
    auto padded = [N](int **M, int r, int c){
        return (r < N && c < N) ? M[r][c] : 0;
    };

    // Split both operands into quadrants. i, j < HalfSize <= N here, so the
    // top-left quadrant never touches the padding.
    for(int i = 0; i < HalfSize; i++){
        for(int j = 0; j < HalfSize; j++){
            A11[i][j] = MatrixA[i][j];
            A12[i][j] = padded(MatrixA, i, j + HalfSize);
            A21[i][j] = padded(MatrixA, i + HalfSize, j);
            A22[i][j] = padded(MatrixA, i + HalfSize, j + HalfSize);
            B11[i][j] = MatrixB[i][j];
            B12[i][j] = padded(MatrixB, i, j + HalfSize);
            B21[i][j] = padded(MatrixB, i + HalfSize, j);
            B22[i][j] = padded(MatrixB, i + HalfSize, j + HalfSize);
        }
    }

    // The seven Strassen products; AResult/BResult are reused as scratch.
    ADD(A11, A22, AResult, HalfSize);
    ADD(B11, B22, BResult, HalfSize);
    Strassen(HalfSize, AResult, BResult, M1);   // M1 = (A11+A22)(B11+B22)
    ADD(A21, A22, AResult, HalfSize);
    Strassen(HalfSize, AResult, B11, M2);       // M2 = (A21+A22)B11
    SUB(B12, B22, BResult, HalfSize);
    Strassen(HalfSize, A11, BResult, M3);       // M3 = A11(B12-B22)
    SUB(B21, B11, BResult, HalfSize);
    Strassen(HalfSize, A22, BResult, M4);       // M4 = A22(B21-B11)
    ADD(A11, A12, AResult, HalfSize);
    Strassen(HalfSize, AResult, B22, M5);       // M5 = (A11+A12)B22
    SUB(A21, A11, AResult, HalfSize);
    ADD(B11, B12, BResult, HalfSize);
    Strassen(HalfSize, AResult, BResult, M6);   // M6 = (A21-A11)(B11+B12)
    SUB(A12, A22, AResult, HalfSize);
    ADD(B21, B22, BResult, HalfSize);
    Strassen(HalfSize, AResult, BResult, M7);   // M7 = (A12-A22)(B21+B22)

    // Recombine: C11 = M1+M4-M5+M7, C12 = M3+M5, C21 = M2+M4,
    //            C22 = M1-M2+M3+M6.
    ADD(M1, M4, AResult, HalfSize);
    SUB(AResult, M5, BResult, HalfSize);
    ADD(BResult, M7, C11, HalfSize);
    ADD(M3, M5, C12, HalfSize);
    ADD(M2, M4, C21, HalfSize);
    ADD(M1, M3, AResult, HalfSize);
    SUB(AResult, M2, BResult, HalfSize);
    ADD(BResult, M6, C22, HalfSize);

    // Copy the quadrants back, skipping anything that falls in the padding.
    for(int i = 0; i < HalfSize; i++){
        const bool lowerInside = i + HalfSize < N;
        for(int j = 0; j < HalfSize; j++){
            const bool rightInside = j + HalfSize < N;
            MatrixC[i][j] = C11[i][j];
            if(rightInside)
                MatrixC[i][j + HalfSize] = C12[i][j];
            if(lowerInside){
                MatrixC[i + HalfSize][j] = C21[i][j];
                if(rightInside)
                    MatrixC[i + HalfSize][j + HalfSize] = C22[i][j];
            }
        }
    }
}
// Writes the MatrixSize x MatrixSize matrix MatrixA to standard output:
// first a newline, then one line per row in which every value is followed
// by a tab character.
void PrintMatrix(int **MatrixA, int MatrixSize){
    std::cout << std::endl;
    for(int r = 0; r < MatrixSize; ++r){
        const int *row = MatrixA[r];
        for(int c = 0; c < MatrixSize; ++c)
            std::cout << row[c] << '\t';
        std::cout << std::endl;
    }
}
// Overwrites every cell of the MatrixSize x MatrixSize matrix MatrixA
// with the value n.
void FillMatrix(int **MatrixA, int MatrixSize, int n){
    for(int r = 0; r < MatrixSize; ++r){
        int *const first = MatrixA[r];
        int *const last = first + MatrixSize;
        for(int *cell = first; cell != last; ++cell)
            *cell = n;
    }
}
// Demo driver: multiplies a 100x100 matrix of 1s by a 100x100 matrix of 2s
// with Strassen and prints the product (every entry is 100 * 1 * 2 = 200).
int main() {
    const int MatrixSize = 100;
    const std::size_t cells =
        static_cast<std::size_t>(MatrixSize) * MatrixSize;

    // Contiguous storage owned by std::vector is released automatically,
    // unlike the per-row new[] allocations that were never deleted. The
    // row-pointer vectors provide the int** views the matrix routines take.
    std::vector<int> storageA(cells);
    std::vector<int> storageB(cells);
    std::vector<int> storageC(cells);
    std::vector<int *> rowsA(MatrixSize);
    std::vector<int *> rowsB(MatrixSize);
    std::vector<int *> rowsC(MatrixSize);
    for(int i = 0; i < MatrixSize; i++){
        const std::size_t offset = static_cast<std::size_t>(i) * MatrixSize;
        rowsA[i] = storageA.data() + offset;
        rowsB[i] = storageB.data() + offset;
        rowsC[i] = storageC.data() + offset;
    }
    int **MatrixA = rowsA.data();
    int **MatrixB = rowsB.data();
    int **MatrixC = rowsC.data();

    FillMatrix(MatrixA, MatrixSize, 1);
    FillMatrix(MatrixB, MatrixSize, 2);
    Strassen(MatrixSize, MatrixA, MatrixB, MatrixC);
    PrintMatrix(MatrixC, MatrixSize);
    return 0;
}