DL4J基本操作
导入依赖
<nd4j.version>1.0.0-beta2</nd4j.version>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
1. 创建矩阵
/*
构造一个3行5列的全0 ndarray
*/
System.out.println("构造一个3行5列的全0 ndarray");
INDArray zeros = Nd4j.zeros(3, 5);
System.out.println(zeros);
/*
构造一个3行5列的全1 ndarray
*/
System.out.println("构造一个3行5列的全1 ndarray");
INDArray ones = Nd4j.ones(3, 5);
System.out.println(ones);
/*
构造一个3行5列,数组元素均为随机产生的ndarray
*/
System.out.println("构造一个3行5列,数组元素均为随机产生的ndarray");
INDArray rands = Nd4j.rand(3, 5);
System.out.println(rands);
/*
构造一个3行5列,数组元素服从高斯分布(平均值为0,标准差为1)的ndarray
*/
System.out.println("构造一个3行5列,数组元素服从高斯分布(平均值为0,标准差为1)的ndarray");
INDArray randns = Nd4j.randn(3, 5);
System.out.println(randns);
/*
给一个一维数据,根据shape创造ndarray
*/
System.out.println("给一个一维数据,根据shape创造ndarray");
System.out.println("创建一个值全是2,一行四列的的ndarray");
INDArray array1 = Nd4j.create(new float[]{2, 2, 2, 2}, new int[]{1, 4});
System.out.println(array1);
System.out.println("创建一个值全是2,2行2列的的ndarray");
INDArray array2 = Nd4j.create(new float[]{2, 2, 2, 2}, new int[]{2, 2});
System.out.println(array2);
2. 矩阵元素读取
System.out.println("把一维数组转换成2行6列的 ndarray");
INDArray nd = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[]{2, 6});
System.out.println("打印原有数组");
System.out.println(nd);
/*
获取指定索引的值
*/
System.out.println("获取数组下标为0, 3的值");
double value = nd.getDouble(0, 3);
System.out.println(value);
/*
修改指定索引的值
*/
System.out.println("修改数组下标为0, 3的值");
//scalar 标量
nd.putScalar(0, 3, 100);
System.out.println(nd);
/*
使用索引迭代器遍历ndarray,使用c order
*/
System.out.println("使用索引迭代器遍历ndarray");
NdIndexIterator iter = new NdIndexIterator(2, 6);
while (iter.hasNext()) {
long[] nextIndex = iter.next();
double nextVal = nd.getDouble(nextIndex);
System.out.println(nextVal);
}
3. 矩阵行元素读取
INDArray nd = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[]{2, 6});
System.out.println("原始数组");
System.out.println(nd);
/*
获取一行
*/
System.out.println("获取数组中的一行:第0行");
INDArray singleRow = nd.getRow(0);
System.out.println(singleRow);
/*
获取多行
*/
System.out.println("获取数组中的多行:0行和1行");
INDArray multiRows = nd.getRows(0, 1);
System.out.println(multiRows);
/*
替换其中的一行
*/
System.out.println("替换原有数组中的一行:把第0行替换掉");
INDArray replaceRow = Nd4j.create(new float[]{1, 3, 5, 7, 9, 11});
nd.putRow(0, replaceRow);
System.out.println(nd);
4. 矩阵运算
// 1x2的行向量
INDArray nd = Nd4j.create(new float[]{1,2},new int[]{1, 2});
// 2x1的列向量
INDArray nd2 = Nd4j.create(new float[]{3,4},new int[]{2, 1}); //vector as column
// 创造两个2x2的矩阵
INDArray nd3 = Nd4j.create(new float[]{1,3,2,4},new int[]{2,2}); //elements arranged column major
INDArray nd4 = Nd4j.create(new float[]{3,4,5,6},new int[]{2, 2});
//打印
System.out.println("打印1x2的行向量");
System.out.println(nd);
System.out.println("打印2x1的列向量");
System.out.println(nd2);
System.out.println("打印创造2x2的矩阵");
System.out.println(nd3);
System.out.println("打印创造2x2的矩阵");
System.out.println(nd4);
System.out.println("---------------");
//1x2 and 2x1 -> 1x1
System.out.println("打印1x2 and 2x1 -> 1x1");
INDArray ndv = nd.mmul(nd2);
System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));
//1x2 and 2x2 -> 1x2
System.out.println("打印1x2 and 2x2 -> 1x2");
ndv = nd.mmul(nd4);
System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));
//2x2 and 2x2 -> 2x2
System.out.println("打印2x2 and 2x2 -> 2x2");
ndv = nd3.mmul(nd4);
System.out.println(ndv + ", shape = " + Arrays.toString(ndv.shape()));