#include<algorithm>
#include<cstdlib>
#include<iostream>
#include<map>
#include<set>
#include<string>
#include<unordered_map>
#include<unordered_set>
#include<vector>
#ifdef _MSC_VER
// <xfunctional> is an MSVC implementation header; GCC/Clang standard libraries do not provide it.
#include<xfunctional>
#endif
using namespace std;
// Returns the element-wise sum A + B.
//
// The result has exactly the shape of A (each row keeps A's row length), so
// square and rectangular matrices are both supported. B must have at least as
// many rows as A, and each of B's rows at least as many entries as the
// corresponding row of A. Operands are taken by const reference to avoid
// copying both matrices on every call.
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B){
    vector<vector<int>> res = A;  // start from A, then accumulate B into it
    for (size_t i = 0; i < res.size(); i++){
        for (size_t j = 0; j < res[i].size(); j++){
            res[i][j] += B[i][j];
        }
    }
    return res;
}
// Returns the element-wise difference A - B.
//
// The result has exactly the shape of A (each row keeps A's row length), so
// square and rectangular matrices are both supported. B must have at least as
// many rows as A, and each of B's rows at least as many entries as the
// corresponding row of A. Operands are taken by const reference to avoid
// copying both matrices on every call.
vector<vector<int>> sub(const vector<vector<int>>& A, const vector<vector<int>>& B){
    vector<vector<int>> res = A;  // start from A, then subtract B from it
    for (size_t i = 0; i < res.size(); i++){
        for (size_t j = 0; j < res[i].size(); j++){
            res[i][j] -= B[i][j];
        }
    }
    return res;
}
// Multiplies two square matrices with Strassen's algorithm (CLRS section 4.2)
// and stores the product A * B in C.
//
// A, B : square n x n matrices, where n = A.size(). Only A.size() is consulted,
//        so B must be at least n x n as well. Any n >= 0 is accepted: an odd n
//        is handled by zero-padding to n + 1 before splitting into quadrants.
// C    : output. Grown to at least n x n if it is smaller; only its top-left
//        n x n entries are written. Left untouched when n == 0.
// All arithmetic is in int, so large entries can overflow just as with the
// schoolbook product.
void multiply(const vector<vector<int>>& A, const vector<vector<int>>& B, vector<vector<int>>& C){
    const int n = static_cast<int>(A.size());
    // Empty product. Without this guard the half size would be 0 and the
    // recursion would never reach the n == 1 base case.
    if (n == 0) return;

    // Make sure every write into C below stays in bounds.
    if (static_cast<int>(C.size()) < n) C.resize(n);
    for (int i = 0; i < n; i++){
        if (static_cast<int>(C[i].size()) < n) C[i].resize(n, 0);
    }

    if (n == 1){
        C[0][0] = A[0][0] * B[0][0];
        return;
    }

    if (n % 2 != 0){
        // The quadrant split needs an even size: embed A and B in (n+1) x (n+1)
        // matrices whose extra row/column is zero, multiply those, and copy the
        // top-left n x n block back. The zero padding contributes nothing.
        const int m = n + 1;
        vector<vector<int>> Ap(m, vector<int>(m, 0));
        vector<vector<int>> Bp(m, vector<int>(m, 0));
        for (int i = 0; i < n; i++){
            for (int j = 0; j < n; j++){
                Ap[i][j] = A[i][j];
                Bp[i][j] = B[i][j];
            }
        }
        vector<vector<int>> Cp(m, vector<int>(m, 0));
        multiply(Ap, Bp, Cp);
        for (int i = 0; i < n; i++){
            for (int j = 0; j < n; j++) C[i][j] = Cp[i][j];
        }
        return;
    }

    const int h = n / 2;  // side length of each quadrant
    auto zero = [h](){ return vector<vector<int>>(h, vector<int>(h, 0)); };

    // Quadrants that appear directly as product operands.
    vector<vector<int>> A11 = zero(), A22 = zero(), B11 = zero(), B22 = zero();
    // The ten sums/differences S1..S10 (CLRS naming).
    vector<vector<int>> s1 = zero(), s2 = zero(), s3 = zero(), s4 = zero(), s5 = zero();
    vector<vector<int>> s6 = zero(), s7 = zero(), s8 = zero(), s9 = zero(), s10 = zero();
    for (int i = 0; i < h; i++){
        for (int j = 0; j < h; j++){
            const int a11 = A[i][j],     a12 = A[i][j + h];
            const int a21 = A[i + h][j], a22 = A[i + h][j + h];
            const int b11 = B[i][j],     b12 = B[i][j + h];
            const int b21 = B[i + h][j], b22 = B[i + h][j + h];
            A11[i][j] = a11;
            A22[i][j] = a22;
            B11[i][j] = b11;
            B22[i][j] = b22;
            s1[i][j]  = b12 - b22;  // S1  = B12 - B22
            s2[i][j]  = a11 + a12;  // S2  = A11 + A12
            s3[i][j]  = a21 + a22;  // S3  = A21 + A22
            s4[i][j]  = b21 - b11;  // S4  = B21 - B11
            s5[i][j]  = a11 + a22;  // S5  = A11 + A22
            s6[i][j]  = b11 + b22;  // S6  = B11 + B22
            s7[i][j]  = a12 - a22;  // S7  = A12 - A22
            s8[i][j]  = b21 + b22;  // S8  = B21 + B22
            s9[i][j]  = a11 - a21;  // S9  = A11 - A21
            s10[i][j] = b11 + b12;  // S10 = B11 + B12
        }
    }

    // The seven recursive half-size products P1..P7.
    vector<vector<int>> p1 = zero(), p2 = zero(), p3 = zero(), p4 = zero();
    vector<vector<int>> p5 = zero(), p6 = zero(), p7 = zero();
    multiply(A11, s1, p1);
    multiply(s2, B22, p2);
    multiply(s3, B11, p3);
    multiply(A22, s4, p4);
    multiply(s5, s6, p5);
    multiply(s7, s8, p6);
    multiply(s9, s10, p7);

    // Assemble the quadrants of C directly (no temporary c11..c22 matrices).
    for (int i = 0; i < h; i++){
        for (int j = 0; j < h; j++){
            C[i][j]         = p5[i][j] + p4[i][j] - p2[i][j] + p6[i][j];  // C11
            C[i][j + h]     = p1[i][j] + p2[i][j];                        // C12
            C[i + h][j]     = p3[i][j] + p4[i][j];                        // C21
            C[i + h][j + h] = p5[i][j] + p1[i][j] - p3[i][j] - p7[i][j];  // C22
        }
    }
}
// Reads a size n followed by two n x n integer matrices A and B (row-major,
// whitespace-separated) from standard input. Computes A * B with multiply()
// and prints the product one row per line, with a space after every entry.
// Exits with status 1 and a message on stderr if n is missing or negative, or
// if the input ends before both matrices have been read.
int main(){
    int n = 0;
    if (!(cin >> n) || n < 0){
        cerr << "error: expected a non-negative matrix size" << '\n';
        return 1;
    }

    // Fills M row by row from cin; returns false as soon as a read fails.
    auto read_matrix = [](vector<vector<int>>& M) -> bool {
        for (vector<int>& row : M){
            for (int& x : row){
                if (!(cin >> x)) return false;
            }
        }
        return true;
    };

    vector<vector<int>> A(n, vector<int>(n, 0));
    vector<vector<int>> B(n, vector<int>(n, 0));
    if (!read_matrix(A) || !read_matrix(B)){
        cerr << "error: expected " << n << "x" << n << " entries for each of A and B" << '\n';
        return 1;
    }

    vector<vector<int>> result(n, vector<int>(n, 0));
    multiply(A, B, result);
    for (const vector<int>& row : result){
        for (int x : row) cout << x << " ";
        cout << '\n';  // avoid endl: no need to flush after every row
    }
    cout.flush();  // make sure everything is shown before any pause prompt
#ifdef _WIN32
    // "pause" is a cmd.exe built-in (presumably used to keep the console
    // window open), so only invoke it on Windows.
    system("pause");
#endif
    return 0;
}
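// 算法导论 (Introduction to Algorithms, CLRS) — Strassen's algorithm
// 最新推荐文章于 2020-12-21 14:08:53 发布 (blog page footer: "latest recommended article published 2020-12-21 14:08:53")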