import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.util.ClassPathResource; import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.examples.dataexamples.CSVExample; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.LearningRatePolicy; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.util.HashMap; import 
java.util.List; import java.util.Map; import org.deeplearning4j.util.DeepLearningIOUtil; /** * Created by Administrator on 2017/10/27 0027. */ public class colorRecognize { private static Logger log = LoggerFactory.getLogger(org.deeplearning4j.examples.dataexamples.CSVExample.class); public static void main(String[] args) throws Exception { // MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork("E:/custom_images/colorRecognize.bin"); //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing LineNumberReader lnr = new LineNumberReader(new FileReader(new File("E:\\DataExamples\\color1.txt"))); lnr.skip(Long.MAX_VALUE); int rows = lnr.getLineNumber() + 1; lnr.close(); int numLinesToSkip = 0; char delimiter = ','; RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); recordReader.initialize(new FileSplit(new File("E:\\DataExamples\\color1.txt"))); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 3; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 8; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2 int batchSize = rows; //Iris data set: 150 examples total. 
We are loading all of them into one DataSet (not recommended for large data sets) DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next(); allData.shuffle(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.7); //Use 65% of data for training final int numInputs = 3; int outputNum = 8; int iterations = 2; long seed = 126; DataSet trainingData = testAndTrain.getTrain(); DataSet testData = testAndTrain.getTest(); List<DataSet> trainList = trainingData.asList(); DataSetIterator trainIterator = new ListDataSetIterator(trainList,10); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data // // normalizer.transform(trainingData); //Apply normalization to the training data // // normalizer.transform(testData); //Apply normalization to the test data. 
This is using statistics calculated from the *training* set // NormalizerSerializer.getDefault().write(normalizer,"E:/mean.txt"); DataNormalization norm1 = NormalizerSerializer.getDefault().restore("E:/mean.txt"); // NormalizerStandardizeSerializer.write(normalizer, "E:/mean.txt"); NormalizerStandardize norm = (NormalizerStandardize)normalizer; // INDArray means = norm.getMean(); // INDArray stv = norm.getStd(); // OutputStream outMeans = new FileOutputStream(new File("E:/mean.txt")); // OutputStream outStdev = new FileOutputStream(new File("E:/stDev.txt")); // Nd4j.write(outMeans,means); // Nd4j.write(outStdev,stv); // NormalizerStandardize norm1 = (NormalizerStandardize)normalizer; // norm1.load(new File("E:/mean.txt"),new File("E:/stDev.txt")); // DataNormalization normalizer2 = norm1; // normalizer2.transform(testData); UIServer uiServer = UIServer.getInstance(); uiServer.enableRemoteListener(); //Necessary: remote support is not enabled by default log.info("Build model...."); Map<Integer, Double> lrSchedule = new HashMap<>(); lrSchedule.put(0, 0.1); lrSchedule.put(100, 0.01); lrSchedule.put(1000, 0.001); lrSchedule.put(3000, 0.0001); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(iterations) .activation(Activation.SIGMOID) .weightInit(WeightInit.XAVIER) .updater(Updater.ADAM) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .learningRateDecayPolicy(LearningRatePolicy.Schedule) .learningRateSchedule(lrSchedule) .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(100) .build()) .layer(1, new DenseLayer.Builder().nIn(100).nOut(100) .build()) .layer(2, new DenseLayer.Builder().nIn(100).nOut(100) .build()) .layer(3, new DenseLayer.Builder().nIn(100).nOut(100) .build()) .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR) .activation(Activation.SOFTMAX) .nIn(100).nOut(outputNum).build()) .backprop(true).pretrain(false) 
.build(); //run the model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(10)); StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000"); model.setListeners(new StatsListener(remoteUIRouter)); for(int i=0;i < 1000;i++){ model.fit(trainIterator); if(i%50 == 0){ Evaluation eval = new Evaluation(8); INDArray output = model.output(testData.getFeatureMatrix()); eval.eval(testData.getLabels(), output); log.info(eval.stats()); ModelSerializer.writeModel(model, "E:/custom_images/colorRecognize.bin1",true); } } ModelSerializer.writeModel(model, "E:/custom_images/colorRecognize.bin1",true); //evaluate the model on the test set Evaluation eval = new Evaluation(8); INDArray output = model.output(testData.getFeatureMatrix()); eval.eval(testData.getLabels(), output); log.info(eval.stats()); System.out.println("+++++++运行结束++++++++++++"); } }
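// 利用dl4j识别图像颜色 (Recognizing image colors with DL4J)
// 最新推荐文章于 2024-07-05 16:24:28 发布 (scraped page footer: "latest recommended article published 2024-07-05 16:24:28")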