// 利用dl4j识别图像颜色 (Using DL4J to recognize image colors)

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.dataexamples.CSVExample;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.deeplearning4j.util.DeepLearningIOUtil;
/**
 * Trains a small fully-connected DL4J network that classifies a colour sample into one of
 * eight classes.
 *
 * <p>Input: a comma-separated text file without a header. Each line holds three numeric feature
 * columns (indices 0-2) followed by an integer class label in column 3.
 * NOTE(review): the three features are presumably R, G, B channel values — confirm against the
 * data file.
 *
 * <p>Side effects: writes the fitted normalizer statistics to {@code E:/mean.txt}, periodically
 * checkpoints the model to {@code E:/custom_images/colorRecognize.bin1}, and sends training
 * statistics to a DL4J UI server at {@code http://localhost:9000}.
 *
 * <p>Created by Administrator on 2017/10/27.
 */
public class colorRecognize {

    private static final Logger log = LoggerFactory.getLogger(colorRecognize.class);

    /** CSV data file: three feature columns followed by the class label. */
    private static final String DATA_FILE = "E:\\DataExamples\\color1.txt";
    /** Where the fitted normalizer statistics are written for use at inference time. */
    private static final String NORMALIZER_FILE = "E:/mean.txt";
    /** Model checkpoint path (written periodically during training and once at the end). */
    private static final String MODEL_FILE = "E:/custom_images/colorRecognize.bin1";

    public static void main(String[] args) throws Exception {
        final int numInputs = 3;   // feature columns 0..2
        final int labelIndex = 3;  // the class label is the 4th column
        final int numClasses = 8;
        final int iterations = 2;
        final long seed = 126;
        final int epochs = 1000;
        final int evalEvery = 50;

        // Load every row into a single DataSet (fine for a small file, not for large data sets).
        DataSet allData = loadAll(new File(DATA_FILE), labelIndex, numClasses);
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.7);  // 70% train / 30% test
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        // Standardize features (mean 0, unit variance). The statistics come from the training
        // split only; fit() does not modify the data, transform() does (in place).
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);
        normalizer.transform(trainingData);
        normalizer.transform(testData);
        // Persist the statistics so inference code can apply the same transformation.
        NormalizerSerializer.getDefault().write(normalizer, NORMALIZER_FILE);

        // Split into mini-batches only after normalization so every batch sees normalized features.
        DataSetIterator trainIterator = new ListDataSetIterator(trainingData.asList(), 10);

        UIServer uiServer = UIServer.getInstance();
        uiServer.enableRemoteListener();        // Necessary: remote support is not enabled by default

        log.info("Build model....");
        MultiLayerNetwork model =
            new MultiLayerNetwork(buildConfiguration(seed, iterations, numInputs, numClasses));
        model.init();

        StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
        // setListeners replaces any listeners set before, so register both in a single call.
        model.setListeners(new ScoreIterationListener(10), new StatsListener(remoteUIRouter));

        for (int epoch = 0; epoch < epochs; epoch++) {
            trainIterator.reset();  // start every epoch from the first mini-batch
            model.fit(trainIterator);
            if (epoch % evalEvery == 0) {
                evaluate(model, testData, numClasses);
                ModelSerializer.writeModel(model, MODEL_FILE, true);
            }
        }

        ModelSerializer.writeModel(model, MODEL_FILE, true);
        // Final evaluation on the held-out test split.
        evaluate(model, testData, numClasses);
        System.out.println("+++++++运行结束++++++++++++");
    }

    /**
     * Reads every row of a headerless, comma-separated file into one DataSet.
     *
     * @param file       the CSV data file
     * @param labelIndex column holding the integer class label
     * @param numClasses number of distinct classes (labels are one-hot encoded to this width)
     * @return all rows of the file as a single DataSet
     */
    private static DataSet loadAll(File file, int labelIndex, int numClasses) throws Exception {
        int rows;
        try (LineNumberReader lnr = new LineNumberReader(new FileReader(file))) {
            lnr.skip(Long.MAX_VALUE);
            // +1 covers a last line without a trailing newline. Over-counting only makes the
            // batch size larger than the row count, so the single next() call still returns every row.
            rows = lnr.getLineNumber() + 1;
        }

        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            recordReader.initialize(new FileSplit(file));
            DataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, rows, labelIndex, numClasses);
            return iterator.next();
        }
    }

    /**
     * Builds an MLP with four 100-unit dense hidden layers (sigmoid activation set as the
     * builder default) and a softmax output layer over {@code numClasses} classes.
     */
    private static MultiLayerConfiguration buildConfiguration(
            long seed, int iterations, int numInputs, int numClasses) {
        // Learning-rate schedule keyed by iteration number.
        Map<Integer, Double> lrSchedule = new HashMap<>();
        lrSchedule.put(0, 0.1);
        lrSchedule.put(100, 0.01);
        lrSchedule.put(1000, 0.001);
        lrSchedule.put(3000, 0.0001);

        return new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(Activation.SIGMOID)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.ADAM)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .learningRateDecayPolicy(LearningRatePolicy.Schedule)
            .learningRateSchedule(lrSchedule)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(100).build())
            .layer(1, new DenseLayer.Builder().nIn(100).nOut(100).build())
            .layer(2, new DenseLayer.Builder().nIn(100).nOut(100).build())
            .layer(3, new DenseLayer.Builder().nIn(100).nOut(100).build())
            // NOTE(review): MCXENT / NEGATIVELOGLIKELIHOOD is the usual loss to pair with a softmax
            // classifier; MSLE is kept here to preserve the original training behaviour.
            .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.SOFTMAX)
                .nIn(100).nOut(numClasses).build())
            .backprop(true).pretrain(false)
            .build();
    }

    /** Runs {@code model} on the test features and logs the classification statistics. */
    private static void evaluate(MultiLayerNetwork model, DataSet testData, int numClasses) {
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatureMatrix());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());
    }
}
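/*
 * The text below is residue from the blog page this code was copied from (CSDN like/favorite/
 * comment widgets, feedback prompt and red-envelope payment dialog). It is not part of the program.
 *
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
*/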