Deep Learning 4J 学习(二) MNIST(手写数字识别)

1.资源数据库下载

http://download.csdn.net/detail/gjq246/9794788

下载后,放到某个文件夹中,比如:D:\doc

 

 

2.创建MyMnistDataFetcher类(在官网示例的基础上做了简单修改:改为从本地路径读取数据集文件,正常情况下无需联网下载)

该类继承自BaseDataFetcher,重点是fetch方法,它实现了按批次读取样本数据的逻辑。

/*
*
*  * Copyright 2015 Skymind,Inc.
*  *
*  *    Licensed under the Apache License, Version 2.0 (the "License");
*  *    you may not use this file except in compliance with the License.
*  *    You may obtain a copy of the License at
*  *
*  *        http://www.apache.org/licenses/LICENSE-2.0
*  *
*  *    Unless required by applicable law or agreed to in writing, software
*  *    distributed under the License is distributed on an "AS IS" BASIS,
*  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  *    See the License for the specific language governing permissions and
*  *    limitations under the License.
*
*/

package com.jiantsing.test;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;


/**
* Data fetcher for the MNIST dataset
* @author Adam Gibson
*
*/
public class MyMnistDataFetcher extends BaseDataFetcher {
    public static final int NUM_EXAMPLES = 60000;
    public static final int NUM_EXAMPLES_TEST = 10000;
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    protected static final String MNIST_ROOT = "D:\\doc\\";//TEMP_ROOT + File.separator + "MNIST" + File.separator;

    protected transient MnistManager man;
    protected boolean binarize = true;
    protected boolean train;
    protected int[] order;
    protected Random rng;
    protected boolean shuffle;


    /**
     * Constructor telling whether to binarize the dataset or not
     * @param binarize whether to binarize the dataset or not
     * @throws IOException
     */
    public MyMnistDataFetcher(boolean binarize) throws IOException {
        this(binarize,true,true,System.currentTimeMillis());
    }

    public MyMnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
//         System.out.println("111");
        if(!mnistExists()) {
//             System.out.println("2222");
            new MnistFetcher().downloadAndUntar();
        }
        String images;
        String labels;
        if(train){
            images = MNIST_ROOT + MnistFetcher.trainingFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.trainingFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES;
        } else {
            images = MNIST_ROOT + MnistFetcher.testFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.testFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES_TEST;
        }

        try {
            man = new MnistManager(images, labels, train);
        }catch(Exception e) {
            FileUtils.deleteDirectory(new File(MNIST_ROOT));
            new MnistFetcher().downloadAndUntar();
            man = new MnistManager(images, labels, train);
        }

        numOutcomes = 10;
        this.binarize = binarize;
        cursor = 0;
        inputColumns = man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;

        if(train){
            order = new int[NUM_EXAMPLES];
        } else {
            order = new int[NUM_EXAMPLES_TEST];
        }
        for( int i=0; i<order.length; i++ ) order[i] = i;
        rng = new Random(rngSeed);
        reset();    //Shuffle order
    }

    private boolean mnistExists(){
        //Check 4 files:
        File f = new File(MNIST_ROOT,MnistFetcher.trainingFilesFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.trainingFileLabelsFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.testFilesFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.testFileLabelsFilename_unzipped);
        if(!f.exists()) return false;
        return true;
    }

    public MyMnistDataFetcher() throws IOException {
        this(true);
    }

    @Override
    public void fetch(int numExamples) {
       
//         System.out.println(numExamples);
         //每一步的大小,batchSize = 128; // batch size for each epoch
        
        if(!hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }


        float[][] featureData = new float[numExamples][0];
        float[][] labelData = new float[numExamples][0];

        int actualExamples = 0;
        for( int i=0; i<numExamples; i++, cursor++ ){
            if(!hasMore()) break;

            byte[] img = man.readImageUnsafe(order[cursor]);//读取图像数据
            int label = man.readLabel(order[cursor]);//读取答案,标签

            float[] featureVec = new float[img.length];
            featureData[actualExamples] = featureVec;//存储128个样本中的一个图像数据
            labelData[actualExamples] = new float[10];//初始化十个分类数组为0
            labelData[actualExamples][label] = 1.0f;//第label为答案,置为1

            for( int j=0; j<img.length; j++ ){
                //byte a = (byte)234;
                //System.out.println(a);
                //结果是-22
                //((int)a) & 0xFF=234
                float v = ((int)img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
                if(binarize){
                    //二值化
                    if(v > 30.0f) featureVec[j] = 1.0f;
                    else featureVec[j] = 0.0f;
                } else {
                    //非二值化,默认选择这个
                    featureVec[j] = v/255.0f;
                }
            }

            actualExamples++;
        }

        if(actualExamples < numExamples){
            featureData = Arrays.copyOfRange(featureData,0,actualExamples);
            labelData = Arrays.copyOfRange(labelData,0,actualExamples);
        }

        INDArray features = Nd4j.create(featureData);
        INDArray labels = Nd4j.create(labelData);
        curr = new DataSet(features,labels);
    }

    @Override
    public void reset() {
        cursor = 0;
        curr = null;
        if(shuffle) MathUtils.shuffleArray(order, rng);
    }

    @Override
    public DataSet next() {
        DataSet next = super.next();
        return next;
    }

}

 

 


3.创建MyMnistDataSetIterator类(根据参数选择训练或者测试数据集):

 

package com.jiantsing.test;

import org.deeplearning4j.datasets.iterator.BaseDatasetIterator;

import java.io.IOException;

public class MyMnistDataSetIterator extends BaseDatasetIterator {

    public MyMnistDataSetIterator(int batch,int numExamples) throws IOException {
        this(batch,numExamples,false);
    }

    /**Get the specified number of examples for the MNIST training data set.
     * @param batch the batch size of the examples
     * @param numExamples the overall number of examples
     * @param binarize whether to binarize mnist or not
     * @throws IOException
     */
    public MyMnistDataSetIterator(int batch, int numExamples, boolean binarize) throws IOException {
        this(batch,numExamples,binarize,true,false,0);
    }

    /** Constructor to get the full MNIST data set (either test or train sets) without binarization (i.e., just normalization
     * into range of 0 to 1), with shuffling based on a random seed.
     * @param batchSize
     * @param train
     * @throws IOException
     */
    public MyMnistDataSetIterator(int batchSize, boolean train, int seed) throws IOException{
        this(batchSize, (train ? MyMnistDataFetcher.NUM_EXAMPLES : MyMnistDataFetcher.NUM_EXAMPLES_TEST), false, train, true, seed);
    }

    /**Get the specified number of MNIST examples (test or train set), with optional shuffling and binarization.
     * @param batch Size of each patch
     * @param numExamples total number of examples to load
     * @param binarize whether to binarize the data or not (if false: normalize in range 0 to 1)
     * @param train Train vs. test set
     * @param shuffle whether to shuffle the examples
     * @param rngSeed random number generator seed to use when shuffling examples
     */
    public MyMnistDataSetIterator(int batch, int numExamples, boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        super(batch, numExamples,new MyMnistDataFetcher(binarize,train,shuffle,rngSeed));
    }

}

4.测试类MyLenetMnistExample(采用LeNet卷积神经网络实现,适当增加nEpochs可以提高正确率)

package com.jiantsing.test;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MyLenetMnistExample {

    private static Logger log = LoggerFactory.getLogger(MyLenetMnistExample.class);
   
    public static void main(String[] args) throws Exception {
        int nChannels = 1; // Number of input channels
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Test batch size
        int nEpochs = 1; // Number of training epochs
        int iterations = 1; // Number of training iterations
        int seed = 123; //

        /*
            Create an iterator using the batch size for one iteration
         */
        log.info("Load data....");
        DataSetIterator mnistTrain = new MyMnistDataSetIterator(batchSize,true,12345);
        DataSetIterator mnistTest = new MyMnistDataSetIterator(batchSize,false,12345);

        /*
            Construct the neural network
         */
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations) // Training iterations as above
                .regularization(true).l2(0.0005)
                /*
                    Uncomment the following for learning decay and bias
                 */
                .learningRate(.01)//.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        //Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28,28,1)) //See note below
                .backprop(true).pretrain(false).build();

        /*
        Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: This does a few things.
        (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling layers
            and the dense layer
        (b) Does some additional configuration validation
        (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of CNNs) values for each
            layer based on the size of the previous layer (but it won't override values manually set by the user)

        InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
        For normal images (when using ImageRecordReader) use InputType.convolutional(height,width,depth).
        MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1) images, in a "flattened"
        row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
        */

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();


        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(100));
        for( int i=0; i<nEpochs; i++ ) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);

            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while(mnistTest.hasNext()){
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);

            }
            log.info(eval.stats());
            mnistTest.reset();
        }
        log.info("****************Example finished********************");
    }

}

 

 

5.运行结果(正确率0.9754):

 

08:49:34,362 INFO  ~ Load data....


08:49:34,997 INFO  ~ Build model....
08:49:35,078 INFO  ~ Loaded [CpuBackend] backend
08:49:35,237 INFO  ~ Number of threads used for NativeOps: 4
08:49:40,622 WARN  ~ Warning: new network default sets pretrain to false.
08:49:40,622 WARN  ~ Warning: new network default sets backprop to true.
08:49:40,854 INFO  ~ Train model....
08:49:41,024 INFO  ~ Number of threads used for BLAS: 4
08:49:41,198 INFO  ~ Score at iteration 0 is 2.3103652104128067
08:50:03,545 INFO  ~ Score at iteration 100 is 0.12810337708838423
08:50:19,942 INFO  ~ Score at iteration 200 is 0.19015166324069205
08:50:33,431 INFO  ~ Score at iteration 300 is 0.049264141885328044
08:50:48,218 INFO  ~ Score at iteration 400 is 0.11165154513286084
08:51:03,657 INFO  ~ Score at iteration 500 is 0.1541011125238029
08:51:16,916 INFO  ~ Score at iteration 600 is 0.08863151381200945
08:51:30,621 INFO  ~ Score at iteration 700 is 0.02344464039862945
08:51:43,909 INFO  ~ Score at iteration 800 is 0.08253156480036103
08:51:57,386 INFO  ~ Score at iteration 900 is 0.013300788175833765
08:52:01,680 INFO  ~ *** Completed epoch 0 ***
08:52:01,680 INFO  ~ Evaluate model....
08:52:07,643 INFO  ~
Examples labeled as 0 classified by model as 0: 969 times
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 0 classified by model as 2: 2 times
Examples labeled as 0 classified by model as 5: 2 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 3 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 1 classified by model as 1: 1122 times
Examples labeled as 1 classified by model as 2: 5 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 1 classified by model as 8: 4 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 3: 1 times
Examples labeled as 2 classified by model as 7: 4 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 2: 10 times
Examples labeled as 3 classified by model as 3: 964 times
Examples labeled as 3 classified by model as 5: 24 times
Examples labeled as 3 classified by model as 7: 8 times
Examples labeled as 3 classified by model as 8: 3 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 978 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 0: 1 times
Examples labeled as 5 classified by model as 3: 2 times
Examples labeled as 5 classified by model as 5: 887 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 6 classified by model as 0: 6 times
Examples labeled as 6 classified by model as 1: 4 times
Examples labeled as 6 classified by model as 2: 2 times
Examples labeled as 6 classified by model as 3: 1 times
Examples labeled as 6 classified by model as 4: 3 times
Examples labeled as 6 classified by model as 5: 13 times
Examples labeled as 6 classified by model as 6: 929 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 18 times
Examples labeled as 7 classified by model as 7: 1005 times
Examples labeled as 7 classified by model as 9: 1 times
Examples labeled as 8 classified by model as 0: 2 times
Examples labeled as 8 classified by model as 2: 5 times
Examples labeled as 8 classified by model as 3: 3 times
Examples labeled as 8 classified by model as 4: 3 times
Examples labeled as 8 classified by model as 5: 12 times
Examples labeled as 8 classified by model as 6: 1 times
Examples labeled as 8 classified by model as 7: 14 times
Examples labeled as 8 classified by model as 8: 925 times
Examples labeled as 8 classified by model as 9: 9 times
Examples labeled as 9 classified by model as 0: 2 times
Examples labeled as 9 classified by model as 1: 4 times
Examples labeled as 9 classified by model as 2: 3 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 24 times
Examples labeled as 9 classified by model as 5: 13 times
Examples labeled as 9 classified by model as 7: 12 times
Examples labeled as 9 classified by model as 8: 1 times
Examples labeled as 9 classified by model as 9: 949 times


==========================Scores========================================
Accuracy:        0.9754
Precision:       0.9755
Recall:          0.9754
F1 Score:        0.9755
========================================================================
08:52:07,644 INFO  ~ ****************Example finished********************

 

6.完整代码:

https://github.com/gjq246/deeplearning4jtest

发布了79 篇原创文章 · 获赞 29 · 访问量 27万+
展开阅读全文

关于Deeplearning4j 官方给的ImagepipelineExample中测试单个图片的问题

05-10

代码如下:

```
public class ImagePipelineExample {

    protected static final Logger log = LoggerFactory.getLogger(ImagePipelineExample.class);

    //Images are of format given by allowedExtension -
    protected static final String [] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;

    protected static final long seed = 12345;

    public static final Random randNumGen = new Random(seed);

    protected static int height = 50;
    protected static int width = 50;
    protected static int channels = 3;

    public static void main(String[] args) throws Exception {
        //DIRECTORY STRUCTURE:
        //Images in the dataset have to be organized in directories by class/label.
        //In this example there are ten images in three classes
        //Here is the directory structure
        //              parentDir
        //             /    |    \
        //            /     |     \
        //        labelB  labelB  labelC
        //
        //Set your data up like this so that labels from each label/class live in their own directory
        //And these label/class directories live together in the parent directory
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);
        ImageRecordReader recordReader = new ImageRecordReader(height,width,channels,labelMaker);

        File trainDir = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/DataExamples/ImagePipeline/");
        InputSplit trainData1=new FileSplit(trainDir);
        recordReader.initialize(trainData1);
        int outputNum = recordReader.numLabels();
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, 10, 1, outputNum);
        while(dataIter.hasNext()){
            DataSet ds=dataIter.next();
            System.out.println("train:"+ds);
        }
        recordReader.reset();
        System.out.println("train Finished!");

        File testDir = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/DataExamples/testlabel/");
        InputSplit testData1=new FileSplit(testDir);
        recordReader.initialize(testData1);
        DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader,10,1,outputNum);//生成测试迭代数据
        while (testIter.hasNext()){
            DataSet ds = dataIter.next();
            System.out.println("test:"+ds);
        }
        recordReader.reset();
    }
}
```

这里有个测试集的目录,我在这个目录下放了一张花的图片。 就直接报错了。

```
train Finished!
Labels:[labelA, labelB, labelC, testlabel]
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 3
    at org.nd4j.linalg.util.FeatureUtil.toOutcomeVector(FeatureUtil.java:38)
    at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:234)
    at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186)
    at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:389)
    at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52)
    at org.deeplearning4j.examples.dataexamples.ImagePipelineExample.main(ImagePipelineExample.java:91)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

Process finished with exit code 1
```

可以看出来,它把测试集的目录也当成一个output了。 请问,要如何测试单张图片的output呢?

问答

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 精致技术 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览