目录
一、矩阵乘法
1.数学基础
要想利用代码实现矩阵的乘法运算,首先我们需要知道有关矩阵相乘的一些基本概念。由于大一下刚好学习了线性代数,所以这里大概地回顾一下相关知识。
1.1矩阵相乘的前提
并不是所有的矩阵都可以相乘,矩阵相乘是有前提条件的,只有当第一个矩阵的列数等于第二个矩阵的行数时,这两个矩阵才可以相乘,否则就无法相乘。假设第一个矩阵A是m行n列的矩阵(即m*n),第二个矩阵B是p行q列的矩阵(即p*q),如果n=p则矩阵A与矩阵B可以相乘,如果n≠p那么矩阵A与矩阵B就不能相乘。
也正是因为这个前提,矩阵相乘是不满足乘法交换律的,有时矩阵A*B可以,但交换后矩阵B*A却不行。
1.2矩阵相乘的规则
为更清晰地理解矩阵相乘的规则,我们作出如下假设:
- 第一个矩阵A为m*n矩阵,其元素用a表示
- 第二个矩阵B为n*p矩阵,其元素用b表示
- 相乘得到的结果矩阵为C,其元素用c表示
显然,矩阵A与矩阵B满足第一个矩阵的列数等于第二个矩阵的行数(n=n),所以可以相乘。接下来,我们将矩阵A的第一行元素与矩阵B的第一列元素对应相乘再相加,得到的结果即为矩阵C第一行第一列的元素,同理,矩阵A的第一行元素与矩阵B的第二列元素对应相乘再相加,得到的结果就是矩阵C第一行第二列的元素……
用图形表示如下:
根据以上相乘规则,很容易可以推断出:
- 结果矩阵C的行数=第一个矩阵A的行数
- 结果矩阵C的列数=第二个矩阵B的列数
借用某站宋浩老师的话来说就是“中间相等,取两头”,再将以上规则总结为数学公式,如下:
2.代码实现
弄清楚相关数学基础后,我们来到代码实现环节。为满足矩阵相乘的前提条件,我们第一步先进行Dimension check:
public static int[][] multiplication(int[][] paraFirstMatrix, int[][] paraSecondMatrix) {
int m = paraFirstMatrix.length;
int n = paraFirstMatrix[0].length;
int p = paraSecondMatrix[0].length;
// Step 1. Dimension check.
if (paraSecondMatrix.length != n) {
System.out.println("The two matrices cannot be multiplied.");
return null;
} // Of if
这里用到了我们昨天提到过的数组名.length,根据昨天的相关分析我们可以很容易地知道:
- paraFirstMatrix.length表示第一个矩阵的行数
- paraFirstMatrix[0].length表示第一个矩阵的列数
- paraSecondMatrix.length表示第二个矩阵的行数
- paraSecondMatrix[0].length表示第二个矩阵的列数
然后,为得到结果矩阵,我们需要用到三层for循环,根据上述矩阵乘法规则,编写代码如下:
// Step 2. The loop.
int[][] resultMatrix = new int[m][p];
for (int i = 0; i < m; i++) {
for (int j = 0; j < p; j++) {
for (int k = 0; k < n; k++) {
resultMatrix[i][j] += paraFirstMatrix[i][k] * paraSecondMatrix[k][j];
} // Of for k
} // Of for j
} // Of for i
return resultMatrix;
}// Of multiplication
这里的三层for循环可以算是这次代码的核心,也是我们需要重点关注的地方。首先,第一层循环for(int k = 0;k < n;k++),k逐一取值实现了第一个矩阵的某一行与第二个矩阵的某一列对应相乘再相加;然后第二层循环for(int j = 0;j < p;j++),j的逐一取值其实也就是第二个矩阵的逐列取得;同理可得,最外层的循环for(int i = 0;i < m;i++),i的逐一取值则实现了第一个矩阵的逐行取得。至此,通过该三层循环就实现了矩阵的乘法运算。
最后,我们照例进行数据测试,如下:
/**
*********************
* Unit test for respective method.
*********************
*/
public static void matrixMultiplicationTest() {
int[][] tempFirstMatrix = new int[2][3];
for (int i = 0; i < tempFirstMatrix.length; i++) {
for (int j = 0; j < tempFirstMatrix[0].length; j++) {
tempFirstMatrix[i][j] = i + j;
} // Of for j
} // Of for i
System.out.println("The first matrix is: \r\n" + Arrays.deepToString(tempFirstMatrix));
int[][] tempSecondMatrix = new int[3][2];
for (int i = 0; i < tempSecondMatrix.length; i++) {
for (int j = 0; j < tempSecondMatrix[0].length; j++) {
tempSecondMatrix[i][j] = i * 10 + j;
} // Of for j
} // Of for i
System.out.println("The second matrix is: \r\n" + Arrays.deepToString(tempSecondMatrix));
int[][] tempThirdMatrix = multiplication(tempFirstMatrix, tempSecondMatrix);
System.out.println("The third matrix is: \r\n" + Arrays.deepToString(tempThirdMatrix));
二、MatrixMultiplication.java
完整的程序代码:
package basic;
import java.util.Arrays;
/**
* This is the eighth code. Names and comments should follow my style strictly.
*
* @author Xin Lin 3101540094@qq.com.
*/
public class MatrixMultiplication {
/**
*********************
* The entrance of the program.
*
* @param args Not used now.
*********************
*/
public static void main(String args[]) {
matrixMultiplicationTest();
}// Of main
/**
*********************
* Matrix multiplication. The columns of the first matrix should be equal to the
* rows of the second one.
*
* @param paraFirstMatrix The first matrix.
* @param paraSecondMatrix The second matrix.
* @return The result matrix.
*********************
*/
public static int[][] multiplication(int[][] paraFirstMatrix, int[][] paraSecondMatrix) {
int m = paraFirstMatrix.length;
int n = paraFirstMatrix[0].length;
int p = paraSecondMatrix[0].length;
// Step 1. Dimension check.
if (paraSecondMatrix.length != n) {
System.out.println("The two matrices cannot be multiplied.");
return null;
} // Of if
// Step 2. The loop.
int[][] resultMatrix = new int[m][p];
for (int i = 0; i < m; i++) {
for (int j = 0; j < p; j++) {
for (int k = 0; k < n; k++) {
resultMatrix[i][j] += paraFirstMatrix[i][k] * paraSecondMatrix[k][j];
} // Of for k
} // Of for j
} // Of for i
return resultMatrix;
}// Of multiplication
/**
*********************
* Unit test for respective method.
*********************
*/
public static void matrixMultiplicationTest() {
int[][] tempFirstMatrix = new int[2][3];
for (int i = 0; i < tempFirstMatrix.length; i++) {
for (int j = 0; j < tempFirstMatrix[0].length; j++) {
tempFirstMatrix[i][j] = i + j;
} // Of for j
} // Of for i
System.out.println("The first matrix is: \r\n" + Arrays.deepToString(tempFirstMatrix));
int[][] tempSecondMatrix = new int[3][2];
for (int i = 0; i < tempSecondMatrix.length; i++) {
for (int j = 0; j < tempSecondMatrix[0].length; j++) {
tempSecondMatrix[i][j] = i * 10 + j;
} // Of for j
} // Of for i
System.out.println("The second matrix is: \r\n" + Arrays.deepToString(tempSecondMatrix));
int[][] tempThirdMatrix = multiplication(tempFirstMatrix, tempSecondMatrix);
System.out.println("The third matrix is: \r\n" + Arrays.deepToString(tempThirdMatrix));
System.out.println("Trying to multiply the first matrix with itself.\r\n");
tempThirdMatrix = multiplication(tempFirstMatrix, tempFirstMatrix);
System.out.println("The result matrix is: \r\n" + Arrays.deepToString(tempThirdMatrix));
}// Of matrixMultiplicationTest
}// Of class MatrixMultiplication
运行结果:
注意到,这里我们进行了tempFirstMatrix的自乘运算,但得到的结果是null,也就是说tempFirstMatrix不能进行自乘。这是因为,tempFirstMatrix是2*3的矩阵,我们假设有两个相同的2*3的矩阵,那么我们可以知道第一个矩阵的列数为3,第二个矩阵的行数为2,二者不等,所以它们不能相乘,所以这里的tempFirstMatrix不能进行自乘运算。
总结
昨天学习了java中的矩阵相加,今天学习了java中的矩阵相乘,这两天的学习中,可能因为上学期刚学完线代,所以对于相关数学基础比较得心应手,但等到了代码实现环节就显得不够熟练了,其实,这主要是源于利用代码模拟数学问题的能力不够,不知道怎么才能将数学问题与程序代码联系起来,所以我们要学会将遇到的问题进行拆分化解,一步一步与代码关联,再一环扣一环,实现模拟。在今天查阅资料的过程中,我发现其实除了构造三层for循环来实现矩阵相乘,java中还有一些现成库也可以用于矩阵相乘,比如EJML库、Colt库、LA4J库等等。