给出三个N*N的矩阵A, B, C,问A * B是否等于C?
Input
第1行,1个数N。(0 <= N <= 500) 第2 - N + 1行:每行N个数,对应矩阵A的元素。(0 <= M[i] <= 16) 第N + 2 - 2N + 1行:每行N个数,对应矩阵B的元素。(0 <= M[i] <= 16) 第2N + 2 - 3N + 1行:每行N个数,对应矩阵C的元素。
Output
如果相等输出Yes,否则输出No。
Input示例
2 1 0 0 1 0 1 1 0 0 1 1 0
Output示例
Yes
矩阵乘法是O(n^3)超时,这题利用了矩阵乘法的结合律用一行随机数将矩阵规模缩小到一维,整体复杂度降到O(n^2)
注意事项
当矩阵A的列数等于矩阵B的行数时,A与B可以相乘。
-
矩阵C的行数等于矩阵A的行数,C的列数等于B的列数。
矩阵乘法基本性质
-
乘法结合律: ( AB) C= A( BC). [2]
-
乘法左分配律:( A+ B) C= AC+ BC [2]
-
乘法右分配律: C( A+ B)= CA+ CB [2]
-
对数乘的结合性 k( AB)=( kA) B= A( kB).
-
转置 ( AB) T= B T A T.
-
矩阵乘法一般不满足交换律 [3] 。
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
#include<map>
#include <set>
//#include <bits/stdc++.h>
using namespace std;
const int N = 520;
typedef long long LL;
LL a[2][N], b[2][N], c[2][N];
LL d[2][N];
const LL seed1=131, seed2=1789;
const LL mod = 10000;
int main()
{
int n;
scanf("%d", &n);
d[0][0]=seed1,d[1][0]=seed2;
for(int i=1;i<=n;i++)
{
d[0][i]=rand()%mod+2;
d[1][i]=rand()%mod+2;
}
// memset(a,0,sizeof(a));
// memset(b,0,sizeof(b));
// memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
LL x;
scanf("%lld", &x);
a[0][j]=(a[0][j]+x*d[0][i]);
a[1][j]=(a[1][j]+x*d[1][i]);
}
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
LL x;
scanf("%lld", &x);
b[0][j]=(b[0][j]+x*a[0][i]);
b[1][j]=(b[1][j]+x*a[1][i]);
}
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
LL x;
scanf("%lld", &x);
c[0][j]=(c[0][j]+x*d[0][i]);
c[1][j]=(c[1][j]+x*d[1][i]);
}
}
int flag=0;
for(int i=1;i<=n;i++)
{
if(b[0][i]!=c[0][i]||b[1][i]!=c[1][i])
{
flag=1;
break;
}
}
if(flag) puts("No");
else puts("Yes");
return 0;
}
奇葩的是我一开始加了个输入挂竟然水过去了
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <map>
#include <vector>
using namespace std;
typedef long long LL;
typedef vector<LL>a;
typedef vector<a>b;
const LL mod = 1000000007;
namespace IN
{
const int inBufferSize = 1<<25;
char inBuffer[inBufferSize];
char *inHead = NULL, *inTail = NULL;
inline char Getchar()
{
if(inHead == inTail)
inTail=(inHead=inBuffer)+fread(inBuffer, 1, inBufferSize, stdin);
return *inHead++;
}
}
#define getchar() IN::Getchar()
template <typename T>
inline void scan_ud(T &ret)
{
char c = getchar();
ret = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9')
ret = ret * 10 + (c - '0'), c = getchar();
}
b jie(b m, b n)
{
b z(m.size(),a(n[0].size()));
for(int i=0; i<m.size(); i++)
{
for(int k=0; k<n.size(); k++)
{
for(int j=0; j<n[0].size(); j++)
{
z[i][j]=(z[i][j]+m[i][k]*n[k][j]);
}
}
}
return z;
}
b Pow(b x,LL n,int m)
{
b y(m,a(m));
for(int i=0; i<m; i++)
y[i][i]=1;
while(n>0)
{
if(n&1)
{
y=jie(y,x);
}
x=jie(x,x);
n>>=1;
}
return y;
}
int main()
{
LL n;
scanf("%lld",&n);
b x(n,a(n));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
scan_ud(x[i][j]);
//scanf("%lld", &);
}
}
b y(n,a(n));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
scan_ud(y[i][j]);
//scanf("%lld", &y[i][j]);
}
}
b z(n,a(n));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
scan_ud(z[i][j]);
//scanf("%lld", &z[i][j]);
}
}
y=jie(x,y);
int flag=0;
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
if(y[i][j]!=z[i][j])
{
flag=1;
break;
}
}
if(flag) break;
}
if(flag) puts("No");
else puts("Yes");
return 0;
}