PM3
USTC has recently developed the Parallel Matrix Multiplication Machine – PM3, which is used for very large matrix multiplication.
Given two matrices A and B, where A is an N × P matrix and B is a P × M matrix, PM3 can compute matrix C = AB in O(P(N + P + M)) time. However the developers of PM3 soon discovered a small problem: there is a small chance that PM3 makes a mistake, and whenever a mistake occurs, the resultant matrix C will contain exactly one incorrect element.
The developers come up with a natural remedy. After PM3 gives the matrix C, they check and correct it. They think it is a simple task, because there will be at most one incorrect element.
So you are to write a program to check and correct the result computed by PM3.
Input
The first line of the input three integers N, P and M (0 < N, P, M ≤ 1,000), which indicate the dimensions of A and B. Then follow N lines with P integers each, giving the elements of A in row-major order. After that the elements of B and C are given in the same manner.
Elements of A and B are bounded by 1,000 in absolute values which those of C are bounded by 2,000,000,000.
Output
If C contains no incorrect element, print “Yes”. Otherwise print “No” followed by two more lines, with two integers r and c on the first one, and another integer v on the second one, which indicates the element of C at row r, column c should be corrected to v.
Sample Input
2 3 2
1 2 -1
3 -1 0
-1 0
0 2
1 3
-2 -1
-3 -2
Sample Output
No
1 2
1
题意:给矩阵A、B、C,问A*B是否完全等于C,等于输出C,不等于输出坐标和那个坐标给的值,而且只会有一个地方有错误。
简单暴力直接乘不用long long就可以过,但是有一种简化的方法。
给A(2,3),B(3,2)可以得到一个C(2,2)的矩阵。
写出C11和C12。 ///字母后的数字是下标
C11 = A11*B11 + A12*B21+ A13*B31;
C12 = A11*B12 + A12*B22+ A13*B32;
两式相加可得
C11+C12 = B11(A11+A21) + B12(A12+A22) + B13(A13+A23);
可以看得出来,C的每一列的和等于B的每一列的每个元素与对应的A的每一列的和的乘积。
这样只需要根据C每一列的和就可以判断是哪一列出了问题,然后在出错的这一列进行矩阵乘法找出错误的元素并输出就行了。
CODE
#include"stdio.h"
#include"algorithm"
#include"iostream"
#include"string.h"
#define maxn 1000+10
using namespace std;
int n,p,m;
int a[maxn][maxn];
int b[maxn][maxn];
int c[maxn][maxn];
int A[maxn]; ///a的每列和
int C[maxn]; ///c的每列和
bool check(int r,int l) ///对应的那一列进行矩阵乘法
{
int t = 0;
for(int i = 1;i <= p;i++)
{
t += a[r][i]*b[i][l];
}
if(t != c[r][l])
{
printf("No\n%d %d\n%d\n",r,l,t);
return true;
}
return false;
}
int main(void)
{
while(scanf("%d%d%d",&n,&p,&m) !=EOF)
{
memset(A,0,sizeof A);
memset(B,0,sizeof B);
memset(C,0,sizeof C);
for(int i = 1;i <= n;i++)
for(int j = 1;j <= p;j++)
{
scanf("%d",&a[i][j]);
A[j] += a[i][j]; ///A的每列和
}
for(int i = 1;i <= p;i++)
{
for(int j = 1;j <= m;j++)
{
scanf("%d",&b[i][j]);
}
}
for(int i = 1;i <= n;i++)
{
for(int j = 1;j <= m;j++)
{
scanf("%d",&c[i][j]);
C[j] += c[i][j]; ///C的每列和
}
}
int flag = 1; ///判断是否有错误
for(int i = 1;i <= m;i++) ///枚举B的每一列
{
int sum = 0;
for(int j = 1;j <= p;j++)
sum += b[j][i]*A[j];
if(sum != C[i]) ///第i列出错
{
flag = 0;
for(int j = 1;j <= n;j++)
{
if(check(j,i))
break;
}
}
}
if(flag)
printf("Yes\n");
}
return 0;
}