题目链接:点击打开链接
题目大意:给出三个n*n矩阵的矩阵a,b,c问a*b是否等于c,等于输出YES,否则输出NO
n的最大值是500,计算矩阵乘法的话需要O(n^3)的复杂度,很明显超时。
随机出一列k,计算a*(b*k) 和c*k,计算出一列的值,这样的如果a*b==c那么a*(b*k) 和c*k也一定会相等的,因为是随机的数,所以可以多测试几次。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
#define LL __int64
LL a[510][510] , b[510][510] , c[510][510] ;
LL k1[510] , k2[510] ;
void read(int n,LL s[][510]) {
int i , j ;
for(i = 1 ; i <= n ; i++)
for(j = 1 ; j <= n ; j++)
scanf("%I64d", &s[i][j]) ;
}
void solve(int n,LL s[][510],LL k[]) {
int p[510] , i , j ;
for(i = 1 ; i <= n ; i++)
p[i] = k[i] ;
for(i = 1 ; i <= n ; i++) {
for(j = 1 , k[i] = 0 ; j <= n ; j++)
k[i] += s[i][j]*p[j] ;
}
}
int main() {
int n , i , num = 10 , cnt = 0 ;
scanf("%d", &n) ;
read(n,a) ;
read(n,b) ;
read(n,c) ;
while( num-- ) {
for(i = 1 ; i <= n ; i++)
k1[i] = k2[i] = rand()%1000 ;
solve(n,b,k1) ;
solve(n,a,k1) ;
solve(n,c,k2) ;
for(i = 1; i <= n ; i++)
if( k1[i] != k2[i] ) break ;
if( i > n ) cnt++ ;
}
if( cnt > 5 )
printf("YES\n") ;
else
printf("NO\n") ;
return 0 ;
}