Deeplearning4j 学习(四):使用 CSV 文件读取训练数据

1.BasicCSVClassifier.java:

 

package com.jiantsing.test;


import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

 

/**
* This example is intended to be a simple CSV classifier that seperates the training data
* from the test data for the classification of animals. It would be suitable as a beginner's
* example because not only does it load CSV data into the network, it also shows how to extract the
* data and display the results of the classification, as well as a simple method to map the lables
* from the testing data into the results.
*
* @author Clay Graham
*/
public class BasicCSVClassifier {

    private static Logger log = LoggerFactory.getLogger(BasicCSVClassifier.class);

    private static Map<Integer,String> eats = readEnumCSV("/animals/eats.csv");
    private static Map<Integer,String> sounds = readEnumCSV("/animals/sounds.csv");
    private static Map<Integer,String> classifiers = readEnumCSV("/animals/classifiers.csv");

    public static void main(String[] args){

        try {

            //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
            //每行数据5个值,前面4个为输入特征,最后一个也就是下标为4的为答案
            int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
            //有3个分类,分别为0,1,2也就是上面的下标为4的取值范围
            int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

            int batchSizeTraining = 30;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
            DataSet trainingData = readCSVDataset(
                    "/animals/animals_train.csv",
                    batchSizeTraining, labelIndex, numClasses);

            // this is the data we want to classify
            int batchSizeTest = 44;
            //原来的animals.csv第四列都为0,作为测试数据是错误的,后面自己标注上去了
            DataSet testData = readCSVDataset("/animals/animals.csv",
                    batchSizeTest, labelIndex, numClasses);


            // make the data model for records prior to normalization, because it
            // changes the data.
            Map<Integer,Map<String,Object>> animals = makeAnimalsForTesting(testData);
            //下标,map{年龄,食物,声音,重量}


            //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
            //对于不是0-1的要做标准化数据处理
            DataNormalization normalizer = new NormalizerStandardize();
            //获取 STDEV 基于样本估算标准偏差
            normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
            //标准化数据,应该是要转为0-1的float???
            normalizer.transform(trainingData);     //Apply normalization to the training data
            normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

            final int numInputs = 4;//输入特征数
            int outputNum = 3;//输出分类个数
            int iterations = 1000;
            long seed = 6;
           
            int numEpochs = 1; // number of epochs to perform

            log.info("Build model....");
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .iterations(iterations)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .learningRate(0.1)
                    .regularization(true).l2(1e-4)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
                    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation(Activation.SOFTMAX).nIn(3).nOut(outputNum).build())
                    .backprop(true).pretrain(false)
                    .build();

            //run the model
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new ScoreIterationListener(100));

            log.info("Train model....");
            for( int i=0; i<numEpochs; i++ ){
                model.fit(trainingData);
            }
           
            //evaluate the model on the test set
            Evaluation eval = new Evaluation(3);
            INDArray output = model.output(testData.getFeatureMatrix());

            eval.eval(testData.getLabels(), output);
            log.info("-------------------------------------------------------");
            log.info(eval.stats());
            log.info("-------------------------------------------------------");

            setFittedClassifiers(output, animals);
            logAnimals(animals);

        } catch (Exception e){
            e.printStackTrace();
        }

    }

 

    public static void logAnimals(Map<Integer,Map<String,Object>> animals){
        for(Map<String,Object> a:animals.values())
            log.info(a.toString());//打印分类结果
    }

    public static void setFittedClassifiers(INDArray output, Map<Integer,Map<String,Object>> animals){
        for (int i = 0; i < output.rows() ; i++) {

//            log.info("-------------");
//            System.out.println(getFloatArrayFromSlice(output.slice(i))[0]);
//            log.info("-------------");
            // set the classification from the fitted results
            animals.get(i).put("classifier",
                    //classifiers是csv下标对照表,把分类id对应成名称,output.slice(i)=[0.00, 0.99, 0.01],对应每种分类的得分,最高值即为分类
                    classifiers.get(maxIndex(getFloatArrayFromSlice(output.slice(i)))));

        }

    }


    /**
     * This method is to show how to convert the INDArray to a float array. This is to
     * provide some more examples on how to convert INDArray to types that are more java
     * centric.
     *
     * @param rowSlice
     * @return
     */
    public static float[] getFloatArrayFromSlice(INDArray rowSlice){
        float[] result = new float[rowSlice.columns()];
        for (int i = 0; i < rowSlice.columns(); i++) {
            result[i] = rowSlice.getFloat(i);
        }
        return result;
    }

    /**
     * find the maximum item index. This is used when the data is fitted and we
     * want to determine which class to assign the test row to
     *
     * @param vals
     * @return
     */
    public static int maxIndex(float[] vals){
        int maxIndex = 0;
        for (int i = 1; i < vals.length; i++){
            float newnumber = vals[i];
            if ((newnumber > vals[maxIndex])){
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * take the dataset loaded for the matric and make the record model out of it so
     * we can correlate the fitted classifier to the record.
     *
     * @param testData
     * @return
     */
    public static Map<Integer,Map<String,Object>> makeAnimalsForTesting(DataSet testData){
        Map<Integer,Map<String,Object>> animals = new HashMap<>();

        INDArray features = testData.getFeatureMatrix();
        for (int i = 0; i < features.rows() ; i++) {
            INDArray slice = features.slice(i);
            Map<String,Object> animal = new HashMap();

            //set the attributes,对应输入的4个特征值
            animal.put("yearsLived", slice.getInt(0));
            //eats,sounds,从列表映射过来
            animal.put("eats", eats.get(slice.getInt(1)));
            animal.put("sounds", sounds.get(slice.getInt(2)));
            animal.put("weight", slice.getFloat(3));

            animals.put(i,animal);
        }
        return animals;

    }


    public static Map<Integer,String> readEnumCSV(String csvFileClasspath) {
        try{
            List<String> lines = IOUtils.readLines(new ClassPathResource(csvFileClasspath).getInputStream());
            Map<Integer,String> enums = new HashMap<>();
            for(String line:lines){
                String[] parts = line.split(",");
                enums.put(Integer.parseInt(parts[0]),parts[1]);
            }
            return enums;
        } catch (Exception e){
            e.printStackTrace();
            return null;
        }

    }

    /**
     * used for testing and training
     *
     * @param csvFileClasspath
     * @param batchSize
     * @param labelIndex
     * @param numClasses
     * @return
     * @throws IOException
     * @throws InterruptedException
     */
    private static DataSet readCSVDataset(
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException{

        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new ClassPathResource(csvFileClasspath).getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr,batchSize,labelIndex,numClasses);
        return iterator.next();//一下子取出30个样本,也就是全部的数据了
    }

 

}

2. CSV 文件:官方自带示例中的 animals.csv,最后一列(标签列,下标为 4)原本全部为 0,缺少真实答案,无法用于评估测试结果,后来自己手工补上了正确的标签。

3.运行结果:

22:51:50,272 INFO  ~ Loaded [CpuBackend] backend
22:51:50,598 INFO  ~ Number of threads used for NativeOps: 4
22:51:51,752 INFO  ~ Build model....
22:52:01,389 WARN  ~ Warning: new network default sets pretrain to false.
22:52:01,390 WARN  ~ Warning: new network default sets backprop to true.
22:52:03,265 INFO  ~ Train model....
22:52:03,823 INFO  ~ Number of threads used for BLAS: 4
22:52:03,888 INFO  ~ Score at iteration 0 is 1.225375238425816
22:52:04,281 INFO  ~ Score at iteration 100 is 0.4846017448250256
22:52:04,665 INFO  ~ Score at iteration 200 is 0.104313639736892
22:52:04,871 INFO  ~ Score at iteration 300 is 0.0474590900935748
22:52:05,135 INFO  ~ Score at iteration 400 is 0.02975828966283905
22:52:05,352 INFO  ~ Score at iteration 500 is 0.02133654741365093
22:52:05,491 INFO  ~ Score at iteration 600 is 0.016472987034082367
22:52:05,645 INFO  ~ Score at iteration 700 is 0.013329871882238171
22:52:05,799 INFO  ~ Score at iteration 800 is 0.011143243188705208
22:52:05,919 INFO  ~ Score at iteration 900 is 0.009540781924404352
22:52:06,152 INFO  ~ -------------------------------------------------------
22:52:06,156 INFO  ~
Examples labeled as 0 classified by model as 0: 8 times
Examples labeled as 1 classified by model as 1: 13 times
Examples labeled as 2 classified by model as 2: 23 times


==========================Scores========================================
Accuracy:        1
Precision:       1
Recall:          1
F1 Score:        1
========================================================================
22:52:06,156 INFO  ~ -------------------------------------------------------
22:52:06,161 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=10.0, yearsLived=19}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=12.0, yearsLived=17}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=16.0, yearsLived=15}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=10.0, yearsLived=19}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=12.0, yearsLived=17}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=16.0, yearsLived=15}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=10.0, yearsLived=19}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=12.0, yearsLived=17}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=16.0, yearsLived=15}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=12.0, yearsLived=17}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=12.0, yearsLived=17}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=16.0, yearsLived=20}
22:52:06,162 INFO  ~ {eats=Mice, sounds=Meow, classifier=Cat, weight=23.0, yearsLived=19}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=60.0, yearsLived=9}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=50.0, yearsLived=9}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=9}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=7}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=40.0, yearsLived=10}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=60.0, yearsLived=9}
22:52:06,163 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=50.0, yearsLived=9}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=9}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=7}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=40.0, yearsLived=10}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=60.0, yearsLived=9}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=50.0, yearsLived=9}
22:52:06,165 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=9}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=7}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=40.0, yearsLived=10}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=7}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=62.0, yearsLived=9}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=70.0, yearsLived=7}
22:52:06,166 INFO  ~ {eats=Cats, sounds=Bark, classifier=Dog, weight=14.0, yearsLived=9}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=200.0, yearsLived=80}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=170.0, yearsLived=65}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=170.0, yearsLived=65}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=170.0, yearsLived=65}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=170.0, yearsLived=44}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Bark, classifier=Dog, weight=65.0, yearsLived=8}
22:52:06,166 INFO  ~ {eats=Hotdogs, sounds=Talking, classifier=Human, weight=150.0, yearsLived=72}
22:52:06,166 INFO  ~ {eats=Ice Cream, sounds=Bark, classifier=Dog, weight=100.0, yearsLived=11}
22:52:06,166 INFO  ~ {eats=Ice Cream, sounds=Bark, classifier=Dog, weight=100.0, yearsLived=11}
22:52:06,166 INFO  ~ {eats=Ice Cream, sounds=Bark, classifier=Dog, weight=100.0, yearsLived=11}
22:52:06,166 INFO  ~ {eats=Ice Cream, sounds=Singing, classifier=Human, weight=160.0, yearsLived=55}
22:52:06,166 INFO  ~ {eats=Ice Cream, sounds=Singing, classifier=Human, weight=300.0, yearsLived=80}

4.所有源码:

https://github.com/gjq246/deeplearning4jtest

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值