Deep Learning 4J Study (Part 2): MNIST (Handwritten Digit Recognition)

1. Download the dataset

http://download.csdn.net/detail/gjq246/9794788

After downloading, place the files in a local folder, e.g. D:\doc
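After extraction the folder should contain the four standard MNIST IDX files. The names below are the usual unzipped MNIST filenames, which the MnistFetcher constants used in the code resolve to (verify against your own download):

D:\doc\
    train-images-idx3-ubyte
    train-labels-idx1-ubyte
    t10k-images-idx3-ubyte
    t10k-labels-idx1-ubyte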

2. Create the MyMnistDataFetcher class (a small modification of the official example: instead of downloading the data over the network, it reads the dataset directly from a local path)

This class extends BaseDataFetcher; the key method is fetch, which implements how batches of sample data are extracted.
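For orientation, the parts of BaseDataFetcher that matter here look roughly like this (a simplified sketch inferred from the fields used below, not the exact DL4J source):

public abstract class BaseDataFetcher implements DataSetFetcher {
    protected int cursor, totalExamples, inputColumns, numOutcomes;
    protected DataSet curr;                      // the batch prepared by fetch(...)
    public boolean hasMore() { return cursor < totalExamples; }
    public DataSet next() { return curr; }       // returns the most recently fetched batch
    public abstract void fetch(int numExamples); // subclasses fill `curr` here
}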

/*
*
*  * Copyright 2015 Skymind,Inc.
*  *
*  *    Licensed under the Apache License, Version 2.0 (the "License");
*  *    you may not use this file except in compliance with the License.
*  *    You may obtain a copy of the License at
*  *
*  *        http://www.apache.org/licenses/LICENSE-2.0
*  *
*  *    Unless required by applicable law or agreed to in writing, software
*  *    distributed under the License is distributed on an "AS IS" BASIS,
*  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  *    See the License for the specific language governing permissions and
*  *    limitations under the License.
*
*/

package com.jiantsing.test;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;


/**
* Data fetcher for the MNIST dataset
* @author Adam Gibson
*
*/
public class MyMnistDataFetcher extends BaseDataFetcher {
    public static final int NUM_EXAMPLES = 60000;
    public static final int NUM_EXAMPLES_TEST = 10000;
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    protected static final String MNIST_ROOT = "D:\\doc\\";//TEMP_ROOT + File.separator + "MNIST" + File.separator;

    protected transient MnistManager man;
    protected boolean binarize = true;
    protected boolean train;
    protected int[] order;
    protected Random rng;
    protected boolean shuffle;


    /**
     * Constructor telling whether to binarize the dataset or not
     * @param binarize whether to binarize the dataset or not
     * @throws IOException
     */
    public MyMnistDataFetcher(boolean binarize) throws IOException {
        this(binarize,true,true,System.currentTimeMillis());
    }

    public MyMnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        if(!mnistExists()) {
            // Fallback: fetch via DL4J's MnistFetcher. Note that this downloads to DL4J's
            // default directory, so for this example the files should already be in MNIST_ROOT.
            new MnistFetcher().downloadAndUntar();
        }
        String images;
        String labels;
        if(train){
            images = MNIST_ROOT + MnistFetcher.trainingFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.trainingFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES;
        } else {
            images = MNIST_ROOT + MnistFetcher.testFilesFilename_unzipped;
            labels = MNIST_ROOT + MnistFetcher.testFileLabelsFilename_unzipped;
            totalExamples = NUM_EXAMPLES_TEST;
        }

        try {
            man = new MnistManager(images, labels, train);
        }catch(Exception e) {
            FileUtils.deleteDirectory(new File(MNIST_ROOT));
            new MnistFetcher().downloadAndUntar();
            man = new MnistManager(images, labels, train);
        }

        numOutcomes = 10;
        this.binarize = binarize;
        cursor = 0;
        inputColumns = man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;

        if(train){
            order = new int[NUM_EXAMPLES];
        } else {
            order = new int[NUM_EXAMPLES_TEST];
        }
        for( int i=0; i<order.length; i++ ) order[i] = i;
        rng = new Random(rngSeed);
        reset();    //Shuffle order
    }

    private boolean mnistExists(){
        //Check 4 files:
        File f = new File(MNIST_ROOT,MnistFetcher.trainingFilesFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.trainingFileLabelsFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.testFilesFilename_unzipped);
        if(!f.exists()) return false;
        f = new File(MNIST_ROOT,MnistFetcher.testFileLabelsFilename_unzipped);
        if(!f.exists()) return false;
        return true;
    }

    public MyMnistDataFetcher() throws IOException {
        this(true);
    }

    @Override
    public void fetch(int numExamples) {
        // numExamples is the number of examples to pull in this call,
        // i.e. the batch size (e.g. batchSize = 128, one batch per step)

        if(!hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }


        float[][] featureData = new float[numExamples][0];
        float[][] labelData = new float[numExamples][0];

        int actualExamples = 0;
        for( int i=0; i<numExamples; i++, cursor++ ){
            if(!hasMore()) break;

            byte[] img = man.readImageUnsafe(order[cursor]); // read the raw image bytes
            int label = man.readLabel(order[cursor]);        // read the ground-truth label

            float[] featureVec = new float[img.length];
            featureData[actualExamples] = featureVec;        // store this example's image within the batch
            labelData[actualExamples] = new float[10];       // one-hot vector over the 10 classes, all zeros
            labelData[actualExamples][label] = 1.0f;         // set the position of the correct digit to 1

            for( int j=0; j<img.length; j++ ){
                // Java bytes are signed: (byte)234 prints as -22, and
                // ((int)a) & 0xFF recovers the unsigned value 234.
                float v = ((int)img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
                if(binarize){
                    // binarize: threshold each pixel to 0 or 1
                    if(v > 30.0f) featureVec[j] = 1.0f;
                    else featureVec[j] = 0.0f;
                } else {
                    // no binarization (the branch used by this example's iterator): normalize into [0, 1]
                    featureVec[j] = v/255.0f;
                }
            }

            actualExamples++;
        }

        if(actualExamples < numExamples){
            featureData = Arrays.copyOfRange(featureData,0,actualExamples);
            labelData = Arrays.copyOfRange(labelData,0,actualExamples);
        }

        INDArray features = Nd4j.create(featureData);
        INDArray labels = Nd4j.create(labelData);
        curr = new DataSet(features,labels);
    }

    @Override
    public void reset() {
        cursor = 0;
        curr = null;
        if(shuffle) MathUtils.shuffleArray(order, rng);
    }

    @Override
    public DataSet next() {
        return super.next();
    }

}
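
A minimal usage sketch (not part of the original post) showing how the fetcher is driven directly; in the next step it is wrapped by an iterator that does this for you:

        // assumes the imports already present in MyMnistDataFetcher
        MyMnistDataFetcher fetcher = new MyMnistDataFetcher(false, true, true, 12345L);
        fetcher.fetch(128);             // pull the next 128 examples into the current DataSet
        DataSet batch = fetcher.next(); // features shape [128, 784], labels shape [128, 10]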

3. Create the MyMnistDataSetIterator class (selects the training or test dataset according to its constructor parameters):


package com.jiantsing.test;

import org.deeplearning4j.datasets.iterator.BaseDatasetIterator;

import java.io.IOException;

public class MyMnistDataSetIterator extends BaseDatasetIterator {

    public MyMnistDataSetIterator(int batch,int numExamples) throws IOException {
        this(batch,numExamples,false);
    }

    /**Get the specified number of examples for the MNIST training data set.
     * @param batch the batch size of the examples
     * @param numExamples the overall number of examples
     * @param binarize whether to binarize mnist or not
     * @throws IOException
     */
    public MyMnistDataSetIterator(int batch, int numExamples, boolean binarize) throws IOException {
        this(batch,numExamples,binarize,true,false,0);
    }

    /** Constructor to get the full MNIST data set (either test or train sets) without binarization (i.e., just normalization
     * into range of 0 to 1), with shuffling based on a random seed.
     * @param batchSize the batch size
     * @param train true for the training set, false for the test set
     * @param seed random seed used for shuffling
     * @throws IOException
     */
    public MyMnistDataSetIterator(int batchSize, boolean train, int seed) throws IOException{
        this(batchSize, (train ? MyMnistDataFetcher.NUM_EXAMPLES : MyMnistDataFetcher.NUM_EXAMPLES_TEST), false, train, true, seed);
    }

    /**Get the specified number of MNIST examples (test or train set), with optional shuffling and binarization.
     * @param batch the size of each batch
     * @param numExamples total number of examples to load
     * @param binarize whether to binarize the data or not (if false: normalize in range 0 to 1)
     * @param train Train vs. test set
     * @param shuffle whether to shuffle the examples
     * @param rngSeed random number generator seed to use when shuffling examples
     */
    public MyMnistDataSetIterator(int batch, int numExamples, boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        super(batch, numExamples,new MyMnistDataFetcher(binarize,train,shuffle,rngSeed));
    }

}
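
A minimal sketch (assumed usage, mirroring the test class in the next step) of how the iterator is constructed for the training and test sets and consumed batch by batch:

        DataSetIterator mnistTrain = new MyMnistDataSetIterator(64, true, 12345);  // 60000 training examples
        DataSetIterator mnistTest  = new MyMnistDataSetIterator(64, false, 12345); // 10000 test examples
        while (mnistTrain.hasNext()) {
            DataSet batch = mnistTrain.next(); // one batch of up to 64 examples
            // feed the batch to a model here
        }
        mnistTrain.reset(); // rewind (and reshuffle, since shuffle=true) before the next epoch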

4. The test class MyLenetMnistExample (a LeNet convolutional network; increasing nEpochs can raise the accuracy somewhat)
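
For reference, the layer-by-layer shape arithmetic of this LeNet-style network (a derivation added for clarity, based on the configuration below):

        input:   28x28x1 grayscale (delivered as a flattened 1x784 row vector, see the note in the code)
        conv 5x5, stride 1, 20 filters:  (28-5)/1+1 = 24  ->  24x24x20
        max-pool 2x2, stride 2:          24/2 = 12        ->  12x12x20
        conv 5x5, stride 1, 50 filters:  (12-5)/1+1 = 8   ->  8x8x50
        max-pool 2x2, stride 2:          8/2 = 4          ->  4x4x50 = 800 values
        dense:   800 -> 500 (ReLU)
        output:  500 -> 10 (softmax over the digit classes)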

package com.jiantsing.test;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MyLenetMnistExample {

    private static Logger log = LoggerFactory.getLogger(MyLenetMnistExample.class);
   
    public static void main(String[] args) throws Exception {
        int nChannels = 1; // Number of input channels
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Test batch size
        int nEpochs = 1; // Number of training epochs
        int iterations = 1; // Number of training iterations
        int seed = 123; // random seed for reproducibility

        /*
            Create an iterator using the batch size for one iteration
         */
        log.info("Load data....");
        DataSetIterator mnistTrain = new MyMnistDataSetIterator(batchSize,true,12345);
        DataSetIterator mnistTest = new MyMnistDataSetIterator(batchSize,false,12345);

        /*
            Construct the neural network
         */
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations) // Training iterations as above
                .regularization(true).l2(0.0005)
                /*
                    Uncomment the following for learning decay and bias
                 */
                .learningRate(.01)//.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        //Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28,28,1)) //See note below
                .backprop(true).pretrain(false).build();

        /*
        Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: This does a few things.
        (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling layers
            and the dense layer
        (b) Does some additional configuration validation
        (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of CNNs) values for each
            layer based on the size of the previous layer (but it won't override values manually set by the user)

        InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
        For normal images (when using ImageRecordReader) use InputType.convolutional(height,width,depth).
        MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1) images, in a "flattened"
        row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
        */

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();


        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(100));
        for( int i=0; i<nEpochs; i++ ) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);

            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while(mnistTest.hasNext()){
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);

            }
            log.info(eval.stats());
            mnistTest.reset();
        }
        log.info("****************Example finished********************");
    }

}
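
The example only reports aggregate statistics. To classify a single test image with the trained network, something like the following could be added at the end of main (a hedged sketch; it additionally assumes an import of org.nd4j.linalg.factory.Nd4j):

        DataSet sample = mnistTest.next(1);                            // take one test example
        INDArray out = model.output(sample.getFeatureMatrix(), false); // 1x10 softmax probabilities
        int predicted = Nd4j.argMax(out, 1).getInt(0);                 // index of the most likely digit
        log.info("Predicted digit: {}", predicted);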

5. Run results (accuracy 0.9754):


08:49:34,362 INFO  ~ Load data....


08:49:34,997 INFO  ~ Build model....
08:49:35,078 INFO  ~ Loaded [CpuBackend] backend
08:49:35,237 INFO  ~ Number of threads used for NativeOps: 4
08:49:40,622 WARN  ~ Warning: new network default sets pretrain to false.
08:49:40,622 WARN  ~ Warning: new network default sets backprop to true.
08:49:40,854 INFO  ~ Train model....
08:49:41,024 INFO  ~ Number of threads used for BLAS: 4
08:49:41,198 INFO  ~ Score at iteration 0 is 2.3103652104128067
08:50:03,545 INFO  ~ Score at iteration 100 is 0.12810337708838423
08:50:19,942 INFO  ~ Score at iteration 200 is 0.19015166324069205
08:50:33,431 INFO  ~ Score at iteration 300 is 0.049264141885328044
08:50:48,218 INFO  ~ Score at iteration 400 is 0.11165154513286084
08:51:03,657 INFO  ~ Score at iteration 500 is 0.1541011125238029
08:51:16,916 INFO  ~ Score at iteration 600 is 0.08863151381200945
08:51:30,621 INFO  ~ Score at iteration 700 is 0.02344464039862945
08:51:43,909 INFO  ~ Score at iteration 800 is 0.08253156480036103
08:51:57,386 INFO  ~ Score at iteration 900 is 0.013300788175833765
08:52:01,680 INFO  ~ *** Completed epoch 0 ***
08:52:01,680 INFO  ~ Evaluate model....
08:52:07,643 INFO  ~
Examples labeled as 0 classified by model as 0: 969 times
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 0 classified by model as 2: 2 times
Examples labeled as 0 classified by model as 5: 2 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 3 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 1 classified by model as 1: 1122 times
Examples labeled as 1 classified by model as 2: 5 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 1 classified by model as 8: 4 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 3: 1 times
Examples labeled as 2 classified by model as 7: 4 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 2: 10 times
Examples labeled as 3 classified by model as 3: 964 times
Examples labeled as 3 classified by model as 5: 24 times
Examples labeled as 3 classified by model as 7: 8 times
Examples labeled as 3 classified by model as 8: 3 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 978 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 0: 1 times
Examples labeled as 5 classified by model as 3: 2 times
Examples labeled as 5 classified by model as 5: 887 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 6 classified by model as 0: 6 times
Examples labeled as 6 classified by model as 1: 4 times
Examples labeled as 6 classified by model as 2: 2 times
Examples labeled as 6 classified by model as 3: 1 times
Examples labeled as 6 classified by model as 4: 3 times
Examples labeled as 6 classified by model as 5: 13 times
Examples labeled as 6 classified by model as 6: 929 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 18 times
Examples labeled as 7 classified by model as 7: 1005 times
Examples labeled as 7 classified by model as 9: 1 times
Examples labeled as 8 classified by model as 0: 2 times
Examples labeled as 8 classified by model as 2: 5 times
Examples labeled as 8 classified by model as 3: 3 times
Examples labeled as 8 classified by model as 4: 3 times
Examples labeled as 8 classified by model as 5: 12 times
Examples labeled as 8 classified by model as 6: 1 times
Examples labeled as 8 classified by model as 7: 14 times
Examples labeled as 8 classified by model as 8: 925 times
Examples labeled as 8 classified by model as 9: 9 times
Examples labeled as 9 classified by model as 0: 2 times
Examples labeled as 9 classified by model as 1: 4 times
Examples labeled as 9 classified by model as 2: 3 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 24 times
Examples labeled as 9 classified by model as 5: 13 times
Examples labeled as 9 classified by model as 7: 12 times
Examples labeled as 9 classified by model as 8: 1 times
Examples labeled as 9 classified by model as 9: 949 times


==========================Scores========================================
Accuracy:        0.9754
Precision:       0.9755
Recall:          0.9754
F1 Score:        0.9755
========================================================================
08:52:07,644 INFO  ~ ****************Example finished********************
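
As a sanity check, the accuracy follows directly from the diagonal of the confusion matrix above: (969 + 1122 + 1026 + 964 + 978 + 887 + 929 + 1005 + 925 + 949) / 10000 = 9754 / 10000 = 0.9754.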

6. Full source code:

https://github.com/gjq246/deeplearning4jtest
