题目地址:
https://leetcode.com/problems/sparse-matrix-multiplication/
给定两个稀疏矩阵 A A A和 B B B,求它们的乘积。题目保证 A A A和 B B B可乘。
法1:遍历 A A A的第 i i i行和 B B B的第 j j j列的时候,用两个指针分别指向它们,然后分别移动,先找到第一个非零数,然后看偏移量是否相等,若相等则累加到乘积 C C C里去,否则直接将偏移量小的指针赋值为偏移量大的指针的值,再重复做同样的事情,直到该行(列)遍历完,就得到了 C [ i ] [ j ] C[i][j] C[i][j]。如此这样遍历两个矩阵的所有行和列即可。代码如下:
public class Solution {
public int[][] multiply(int[][] A, int[][] B) {
if (A == null || A.length == 0 || A[0].length == 0 || B == null || B.length == 0 || B[0].length == 0) {
return new int[0][0];
}
int p = A.length, q = A[0].length, r = B[0].length;
int[][] res = new int[p][r];
// 计算res[i][j]
for (int i = 0; i < p; i++) {
for (int j = 0; j < r; j++) {
int idx1 = 0, idx2 = 0;
while (idx1 < q && idx2 < q) {
// 找到A第i行的下一个非零数
while (idx1 < q && A[i][idx1] == 0) {
idx1++;
}
// 如果出界了则直接退出循环
if (idx1 == q) {
break;
}
// 找到B第j列的下一个非零数
while (idx2 < q && B[idx2][j] == 0) {
idx2++;
}
// 如果出界了则直接退出循环
if (idx2 == q) {
break;
}
// 如果两个偏移量相等,说明A[i][idx1]和B[idx1][j]都非零,
// 则乘起来并累加到res[i][j]上去,并将两个指针都向后移动一位;
// 如果偏移量不等,譬如idx1 < idx2,那么我们起码知道B[idx1, ..., idx2 - 1][j]都等于0,
// 这时应该直接将idx1赋值为idx2,否则将idx2赋值为idx1
if (idx1 == idx2) {
res[i][j] += A[i][idx1] * B[idx1][j];
idx1++;
idx2++;
} else {
int max = Math.max(idx1, idx2);
idx1 = idx2 = max;
}
}
}
}
return res;
}
}
时间复杂度 O ( p q r ) O(pqr) O(pqr), p p p和 q q q分别是 A A A的行数和列数, r r r是 B B B的列数。空间 O ( 1 ) O(1) O(1)(不计返回结果的空间)。由于是稀疏矩阵,所以做乘法的机会是比较少的,实际运行速度会比朴素乘法快很多。
法2:我们可以发现, A [ i ] [ j ] A[i][j] A[i][j]对最终乘积 C C C的贡献只体现在 A [ i ] [ j ] A[i][j] A[i][j]和 B B B的第 j j j行各个数字相乘上。所以可以直接遍历 A A A的非零项 A [ i ] [ j ] A[i][j] A[i][j],然后再遍历 B B B的第 j j j行,把 A [ i ] [ j ] A[i][j] A[i][j]和 B [ j ] [ k ] B[j][k] B[j][k]的乘积累加到 C [ i ] [ k ] C[i][k] C[i][k]上即可。代码如下:
public class Solution {
public int[][] multiply(int[][] A, int[][] B) {
if (A == null || A.length == 0 || A[0].length == 0 || B == null || B.length == 0 || B[0].length == 0) {
return new int[0][0];
}
int p = A.length, q = A[0].length, r = B[0].length;
int[][] res = new int[p][r];
// 外面两层循环是遍历A
for (int i = 0; i < p; i++) {
for (int j = 0; j < q; j++) {
if (A[i][j] != 0) {
// 再遍历B的第j行
for (int k = 0; k < r; k++) {
if (B[j][k] != 0) {
// 累加
res[i][k] += A[i][j] * B[j][k];
}
}
}
}
}
return res;
}
}
时间复杂度 O ( x r ) O(xr) O(xr), x x x为 A A A的非零项的个数。空间 O ( 1 ) O(1) O(1)。