This example is fairly simple; I am writing this post mainly to keep a few quick notes, so I don't waste time on the same issues when I run into them again later.
The example covers reading in CSV data, normalizing it, building a simple neural network, and then training and predicting.
package org.deeplearning4j.examples.dataExamples;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author Adam Gibson
 */
public class CSVExample {

    // Create a logger so we can print log output
    private static Logger log = LoggerFactory.getLogger(CSVExample.class);

    public static void main(String[] args) throws Exception {
        // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
        int numLinesToSkip = 0;  // number of lines to skip when reading; some files have a header row, some don't
        String delimiter = ",";  // separator between values in the file
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); // the file reader
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); // read the file from disk

        // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
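As an aside, the parsing CSVRecordReader does for us here is conceptually just "skip N lines, split each remaining line on the delimiter". A minimal plain-Java sketch of that logic (the class and method names below are my own, not part of DataVec):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CsvSketch {
    // Skip the first numLinesToSkip lines, then split each line on the delimiter
    static List<List<String>> parse(List<String> lines, int numLinesToSkip, String delimiter) {
        List<List<String>> records = new ArrayList<>();
        for (int i = numLinesToSkip; i < lines.size(); i++) {
            records.add(Arrays.asList(lines.get(i).split(delimiter)));
        }
        return records;
    }

    public static void main(String[] args) {
        // Two iris-style rows: 4 features followed by an integer class label
        List<String> lines = Arrays.asList("5.1,3.5,1.4,0.2,0", "4.9,3.0,1.4,0.2,0");
        List<List<String>> records = parse(lines, 0, ",");
        System.out.println(records.size() + " records, " + records.get(0).size() + " values each");
    }
}
```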
        int labelIndex = 4;   // position of the label: each row of the iris.txt CSV has 5 values, 4 input features followed by an integer class label, so the label is the 5th value (index 4)
        int numClasses = 3;   // number of classes (types of iris flowers) in the iris data set; the class values are 0, 1 or 2
        int batchSize = 150;  // the iris data set has 150 examples in total, so this loads all of them into one DataSet (not recommended for large data sets)

        // Wrap the reader in an iterator; the arguments are: reader, batch size, label position, number of classes
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        DataSet allData = iterator.next(); // convert the data into a DataSet
        allData.shuffle();                 // shuffle the examples

        // Split into a training set and a test set
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); // use 65% of the data for training
        DataSet trainingData = testAndTrain.getTrain(); // get the training set
        DataSet testData = testAndTrain.getTest();      // get the test set
        System.out.println("allData = " + allData.numExamples() + " train = " + trainingData.numExamples());

        // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
        DataNormalization normalizer = new NormalizerStandardize(); // normalizer for the data
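The shuffle-then-split step is easy to reproduce by hand, which makes the 65/35 cut concrete. A minimal sketch (class and method names are mine, not DL4J's) that shuffles 150 example indices with a fixed seed and cuts them at the training fraction:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SplitSketch {
    // Shuffle the example indices with a fixed seed, then cut at the given fraction
    static List<List<Integer>> split(int numExamples, double trainFraction, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) indices.add(i);
        Collections.shuffle(indices, new Random(seed));
        int cut = (int) (numExamples * trainFraction);
        List<List<Integer>> result = new ArrayList<>();
        result.add(new ArrayList<>(indices.subList(0, cut)));           // training indices
        result.add(new ArrayList<>(indices.subList(cut, numExamples))); // test indices
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(150, 0.65, 6);
        System.out.println("train = " + parts.get(0).size() + ", test = " + parts.get(1).size());
    }
}
```

With 150 examples and a 0.65 fraction this gives 97 training and 53 test examples (the fractional cut is truncated); DL4J's own splitTestAndTrain may round differently, which I have not verified.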
        normalizer.fit(trainingData);       // collect the statistics (mean/stdev) from the training data; this does not modify the input data
        normalizer.transform(trainingData); // standardize the training set
        normalizer.transform(testData);     // standardize the test set using the statistics computed from the training set
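What NormalizerStandardize does under the hood is ordinary standardization: fit computes the mean and standard deviation per feature from the training data only, and transform applies z = (x - mean) / stdev to whatever set you pass it. A plain-Java sketch for a single feature (names are mine; note I use the population standard deviation here, and have not checked whether DL4J divides by n or n-1):

```java
public class StandardizeSketch {
    // Mean of the values
    static double mean(double[] xs) {
        double sum = 0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    // Population standard deviation
    static double stdev(double[] xs, double mean) {
        double sum = 0;
        for (double x : xs) sum += (x - mean) * (x - mean);
        return Math.sqrt(sum / xs.length);
    }

    // Standardize in place using the supplied statistics: z = (x - mean) / stdev
    static void transform(double[] xs, double mean, double stdev) {
        for (int i = 0; i < xs.length; i++) xs[i] = (xs[i] - mean) / stdev;
    }

    public static void main(String[] args) {
        double[] train = {4.0, 6.0, 8.0};
        double[] test = {5.0, 7.0};
        double m = mean(train);    // fit on the training data only
        double s = stdev(train, m);
        transform(train, m, s);    // then apply the same statistics to both sets
        transform(test, m, s);
        System.out.println(java.util.Arrays.toString(train) + " " + java.util.Arrays.toString(test));
    }
}
```

The important point, which the example above follows, is that the test set is transformed with the training set's statistics, never its own.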
        final int numInputs = 4; // dimensionality of the input data
        int outputNum = 3;       // number of classes
        int iterations = 1000;   // number of training iterations
        long seed = 6;           // random seed

        log.info("Build model....");
        // Configure the network structure
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(Activation.TANH)   // tanh activation function
            .weightInit(WeightInit.XAVIER) // weight initialization
            .learningRate(0.1)             // learning rate
            .regularization(true).l2(1e-4) // L2 regularization
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) // first layer: 4 input nodes, 3 output nodes
                .build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)         // second layer: 3 inputs, 3 outputs
                .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)                       // softmax activation on the output layer
                .nIn(3).nOut(outputNum).build())
            .backprop(true).pretrain(false) // train with backpropagation, no pretraining
            .build();
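The output layer pairs softmax with the negative log likelihood loss: softmax turns the three raw outputs into class probabilities that sum to 1, and the loss is the negative log of the probability assigned to the true class. A worked plain-Java sketch of that math (names are mine, not DL4J's):

```java
public class SoftmaxSketch {
    // Softmax: exponentiate and normalize so the outputs sum to 1
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) max = Math.max(max, l); // subtract the max for numerical stability
        double sum = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) probs[i] /= sum;
        return probs;
    }

    // Negative log likelihood: penalize low probability on the true class
    static double nll(double[] probs, int trueClass) {
        return -Math.log(probs[trueClass]);
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[]{2.0, 1.0, 0.1});
        System.out.println("p = " + java.util.Arrays.toString(probs) + ", loss = " + nll(probs, 0));
    }
}
```

The loss is 0 only when the network assigns probability 1 to the correct class, which is what gradient descent pushes it toward.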
        // Run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100)); // log the score every 100 iterations
        model.fit(trainingData); // start training

        // Evaluate the model on the test set
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatureMatrix()); // feed the test features through the network to get the predicted values
        eval.eval(testData.getLabels(), output); // compare the true labels with the predictions
        log.info(eval.stats()); // print the evaluation statistics
    }
}
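At its core, the evaluation step turns each row of softmax probabilities into a predicted class via argmax and counts how often that matches the true label (Evaluation also reports precision, recall and a confusion matrix, which I skip here). A minimal plain-Java sketch of the accuracy part (names are mine, not DL4J's):

```java
public class EvalSketch {
    // Predicted class = index of the largest probability (argmax)
    static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }

    // Accuracy = fraction of rows whose argmax matches the true label
    static double accuracy(double[][] predictions, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(predictions[i]) == labels[i]) correct++;
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] preds = {{0.8, 0.1, 0.1}, {0.2, 0.7, 0.1}, {0.3, 0.3, 0.4}};
        int[] labels = {0, 1, 2};
        System.out.println("accuracy = " + accuracy(preds, labels));
    }
}
```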
The results are as follows:
If you spot any problems, please point them out. Thanks!