Deeplearning4j的数据是由一个叫做DataSet的对象传入网络进行训练的,DataSet由四个主要元素组成,Features,Labels,FeaturesMask,LabelsMask,这四个元素都是INDArray,即是N维矩阵或者叫做N维张量。一般来说是2-4维矩阵,分别对应全联接网络、RNN网络、CNN网络的输入。
四个元素简单介绍如下:
- Features 特征,特征可以是N维矩阵,以RNN举例,RNN的输入矩阵各维度的维数是[MiniBatch,FeaturesLength,TimeSeqLength],其中第二个维度就是特征列的个数,单个值[x,y,z]的含义就是MiniBatch的某一个批次x,时间序列TimeSeq的某一个时间点z,某一个Feature特征y的值。
- Labels 标签,标签的维度需要和特征相对应,还是以RNN为例,标签的维度的维数就是[MiniBatch,LabelLength,TimeSeqLength],其中,如果网络是一个分类器网络的话,LabelLength是对应标签的独热处理,即LabelLength相当于分类数classes;如果网络是一个回归函数的话,那么LabelLength就是对应输出的几个回归目标值y的个数(一般是一个)。即是说,FeaturesLength相当于网络的InputSize;LabelLength相当于网络的OutputSize。输入输出的宽度。
- FeaturesMask以及LabelsMask 特征掩模和标签掩模,如果需要掩盖某些数据的输入输出,即我们需要扔掉一些数据的输入或者输出,比方说RNN序列输出我只需要输出最后一个,或者输入我只需要前三个,那么这两个元素就有用了。以RNN为例,如果我的Label每个TimeSeq只输出最后一个时间点的值,那么LabelsMask就可以这么写,labelsMask的维度是[MiniBatch,TimeSeqLength],比方说是[x,y]当且仅当y = TimeSeqLength - 1 的时候[x,y] = 1,其余[x,y] = 0。这样就写好了一个输出的掩模。
关于RNN掩模的具体介绍可以看官网:
Dataset的初始化方法源码:
/**
* Creates a dataset with the specified input matrix and labels
*
* @param first the feature matrix
* @param second the labels (these should be binarized label matrices such that the specified label
* has a value of 1 in the desired column with the label)
*/
public DataSet(INDArray first, INDArray second) {
this(first, second, null, null);
}
/**Create a dataset with the specified input INDArray and labels (output) INDArray, plus (optionally) mask arrays
* for the features and labels
* @param features Features (input)
* @param labels Labels (output)
* @param featuresMask Mask array for features, may be null
* @param labelsMask Mask array for labels, may be null
*/
public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
this.features = features;
this.labels = labels;
this.featuresMask = featuresMask;
this.labelsMask = labelsMask;
// we want this dataset to be fully committed to device
Nd4j.getExecutioner().commit();
}
可以用getRange函数来截取一部分数据。
DataSet可以用merge函数竖向拼接。
可以用load函数来从流(或将文件转化为流)中读取数据。
可以用save函数来将DataSet存成流或者文件。Save和load函数会规定好用一个Byte的数据表示读取的一些属性,这样load的时候就能正确解析。