矩阵的简单操作代码实现
转置定义(transpose):
转置是矩阵的重要操作之一。矩阵的转置是以对角线为轴的镜像,这条从左上角到右下角的对角线称为主对角线(main diagonal)。下图显示这个操作。将矩阵A转置表示为A^T,定义如下:
(参考 《深度学习》–[美]伊恩·古德费洛,[加]约书亚·本吉奥,[加]亚伦·库维尔 第2章)
向量可以看作只有一列的矩阵。对应地,向量的转置可以看作只有一行的矩阵。有时,我们通过将向量元素作为行矩阵写在文本行中,然后使用转置操作将其变为标准的列向量,来定义一个向量 (说明矩阵转置不主要要求为方阵),下图演示了转置的过程:
代码实现:
用java的数组进行简单实现,其实就是数组的操作:
package cn.qulei.matrix;
/**
* 实现矩阵转置
*
* @author QuLei
*/
public class MatrixTranspose {
public static void main(String[] args) {
int[][] matrix = new int[3][4];
int k = 1;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
matrix[i][j] = k++;
}
}
int[][] transposedMatrix = transpose(matrix);
System.out.println("输出原矩阵:");
print(matrix);
System.out.println("-----------------");
System.out.println("转置后的矩阵:");
print(transposedMatrix);
}
/**
* 转置矩阵操作
*
* @param matrix 待转置矩阵
* @return 转置后的矩阵
*/
private static int[][] transpose(int[][] matrix) {
int[][] transposedMatrix = new int[matrix[0].length][matrix.length];
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
transposedMatrix[j][i] = matrix[i][j];
}
}
return transposedMatrix;
}
/**
*封装打印数组方法
*
* @param matrix 待打印数组
*/
private static void print(int[][] matrix) {
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
System.out.print(matrix[i][j]);
System.out.print("\t");
}
System.out.println();
}
}
}
测试用例结果:
输出原矩阵:
1 2 3 4
5 6 7 8
9 10 11 12
-----------------
转置后的矩阵:
1 5 9
2 6 10
3 7 11
4 8 12
Process finished with exit code 0
矩阵的乘积
简单定义:
参考:《深度学习的数学》–[日]涌井良幸,[日]涌井贞美 第2-5节
代码实现
import org.junit.Test;
/**
* 测试矩阵相乘
*
* @author QuLei
*/
public class TestMatrix {
@Test
public void test() {
int[][] arr1 = {{2, 7}, {1, 8}};
int[][] arr2 = {{2, 8}, {1, 3}};
int[][] milt = milt(arr1, arr2);
for (int i = 0; i < milt.length; i++) {
for (int j = 0; j < milt[0].length; j++) {
System.out.print(milt[i][j] + "\t");
}
System.out.println();
}
}
/**
* 计算两矩阵相乘
*
* @param arr1
* @param arr2
* @return
*/
int[][] milt(int[][] arr1, int[][] arr2) {
//这里为了实现数学上的问题,简化了对矩阵合法性的判断
if (arr1[0].length != arr2.length) {
throw new RuntimeException("不满足矩阵乘法基本要求!!!");
}
int row = arr1.length;
int col = arr2[0].length;
int[][] result = new int[row][col];
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
for (int k = 0; k < arr2.length; k++) {
result[i][j] += arr1[i][k] * arr2[k][j];
}
}
}
return result;
}
}
测试结果:
11 37
10 32