Java Deep Learning: Building a NiN Network with DJL

1. A NiN block consists of one convolutional layer followed by two 1×1 convolutional layers, which act as per-pixel fully connected layers.

The code to build a NiN block is as follows:


    public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
                                           Shape strideShape, Shape paddingShape) {

        SequentialBlock tempBlock = new SequentialBlock();
        // numChannels:  number of output channels (filters)
        // kernelShape:  convolution kernel size
        // strideShape:  stride
        // paddingShape: padding size
        tempBlock.add(Conv2d.builder()
                .setKernelShape(kernelShape)
                .optStride(strideShape)
                .optPadding(paddingShape)
                .setFilters(numChannels)
                .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu);

        return tempBlock;
    }
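Before wiring several of these blocks into a full network, it can help to sanity-check the output shape of a single NiN block. The following is a minimal sketch (not part of the original article) that uses the niNBlock method above together with DJL's getOutputShapes, assuming a 3-channel 224×224 input like the dummy input used later:

    try (NDManager manager = NDManager.newBaseManager()) {
        SequentialBlock block = niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0));
        // getOutputShapes only propagates shapes, so the block does not need to be initialized
        Shape[] out = block.getOutputShapes(manager, new Shape[]{new Shape(1, 3, 224, 224)});
        // (224 - 11) / 4 + 1 = 54, so this prints (1, 96, 54, 54);
        // the two 1x1 convolutions change neither the channel count nor the spatial size
        System.out.println(out[0]);
    }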

2. The code to build the full network is as follows:

        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        SequentialBlock block = new SequentialBlock();
        // build the NiN network
        block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // dropout layer
                .add(Dropout.builder().optRate(0.5f).build())
                // There are 10 label classes
                .add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                // The global average pooling layer automatically sets the window shape
                // to the height and width of the input
                .add(Pool.globalAvgPool2dBlock())
                // Transform the four-dimensional output into two-dimensional output
                // with a shape of (batch size, 10)
                .add(Blocks.batchFlattenBlock());
        // learning rate
        float lr = 0.05f;
        Model model = Model.newInstance("ninCovid19");
        model.setBlock(block);

What remains is preparing the data and training the model.

The complete code is as follows:

package com.example.demo.djl;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import com.example.demo.djl.covid19.Covid19Models;
import com.example.demo.djl.covid19.Covid19Training;
import org.apache.commons.lang3.ArrayUtils;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class NiNTest {

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // directory where the trained model will be saved
        Path modelDir = Paths.get("nin");
        // learning rate
        float lr = 0.05f;

        Loss loss = Loss.softmaxCrossEntropyLoss();

        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        try (NDManager manager = NDManager.newBaseManager(); Model model = getNiNModel(); Trainer trainer = model.newTrainer(config)) {
            // randomUniform samples from a uniform distribution over [low, high):
            // the lower bound is inclusive, the upper bound is exclusive.
            //   low:   lower bound of the sampling range
            //   high:  upper bound of the sampling range
            //   shape: shape of the output NDArray
            Block block = model.getBlock();

            NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 3, 224, 224));

            trainer.initialize(X.getShape());

            Shape currentShape = X.getShape();

            for (int i = 0; i < block.getChildren().size(); i++) {

                Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(manager, new Shape[]{currentShape});
                currentShape = newShape[0];
                // print the output shape of each top-level child block
                System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
            }
            // batch size
            int batchSize = 128;
            // number of training epochs
            int numEpochs = 30;
            // training loss per epoch
            double[] trainLoss;
            // validation accuracy per epoch
            double[] testAccuracy;
            double[] epochCount;
            // training accuracy per epoch
            double[] trainAccuracy;

            epochCount = new double[numEpochs];

            for (int i = 0; i < epochCount.length; i++) {
                epochCount[i] = i + 1;
            }
            // FashionMnist images are 28x28 grayscale with a single channel; if you train on FashionMnist
            // instead of the x-ray dataset, change the dummy input above to new Shape(1, 1, 224, 224)
//        FashionMnist trainIter = FashionMnist.builder()
//                .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
//                .optUsage(Dataset.Usage.TRAIN)
//                .setSampling(batchSize, true)
//                .build();
//
//        FashionMnist testIter = FashionMnist.builder()
//                .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
//                .optUsage(Dataset.Usage.TEST)
//                .setSampling(batchSize, true)
//                .build();
//
//        trainIter.prepare();
//        testIter.prepare();
            // use the COVID-19 x-ray images instead
            ImageFolder dataset = Covid19Training.initDataset("D:\\covid19dataset\\COVID-19 Radiography Database\\");
            // split into training (80%) and validation (20%) datasets
            RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

            RandomAccessDataset trainIter = datasets[0];

            RandomAccessDataset testIter = datasets[1];

            trainIter.prepare();
            testIter.prepare();


            Map<String, double[]> evaluatorMetrics = new HashMap<>();
            // trainingChapter6 returns the average time per epoch (in nanoseconds);
            // a double argument is passed by value in Java, so it cannot be updated by the callee
            double avgTrainTimePerEpoch = trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);

            trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
            trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
            testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

            System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
            System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
            System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
            System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
            System.out.println();

            String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

            Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
            Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
            Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                    trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

            model.save(modelDir, "ninCovid19");

            // save labels into model directory
            Covid19Models.saveSynset(modelDir, dataset.getSynset());

            Table data = Table.create("Data").addColumns(
                    DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
                    DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
                    StringColumn.create("lossLabel", lossLabel)
            );
            // plotting: see the tablesaw-jsplot sketch after the pom.xml snippet below

        }

    }

    public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
                                           Shape strideShape, Shape paddingShape) {

        SequentialBlock tempBlock = new SequentialBlock();
        // numChannels:  number of output channels (filters)
        // kernelShape:  convolution kernel size
        // strideShape:  stride
        // paddingShape: padding size
        tempBlock.add(Conv2d.builder()
                .setKernelShape(kernelShape)
                .optStride(strideShape)
                .optPadding(paddingShape)
                .setFilters(numChannels)
                .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu);

        return tempBlock;
    }


    public static double trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
                                          int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics) throws IOException, TranslateException {

        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });

        // return the average time per epoch so the caller can compute examples/sec
        return metrics.mean("epoch");
    }

    public static Model getNiNModel() {
        SequentialBlock block = new SequentialBlock();
        // build the NiN network
        block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // dropout layer
                .add(Dropout.builder().optRate(0.5f).build())
                // 10 output channels, one per label class (kept from the FashionMnist version of this
                // network; for a dataset with fewer classes this could be reduced accordingly)
                .add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                // The global average pooling layer automatically sets the window shape
                // to the height and width of the input
                .add(Pool.globalAvgPool2dBlock())
                // Transform the four-dimensional output into two-dimensional output
                // with a shape of (batch size, 10)
                .add(Blocks.batchFlattenBlock());

        Model model = Model.newInstance("ninCovid19");
        model.setBlock(block);
        return model;
    }
}

The code that prepares the training data (the Covid19Training and Covid19Models helper classes) can be found in the earlier article "JAVA深度学习框架DJL之COVID19 x-ray图片分类".
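For reference, here is a hypothetical sketch of what such an initDataset helper could look like using DJL's ImageFolder dataset; the method name, batch size, and resize values are assumptions made for illustration, and the actual implementation is the one in the referenced article (the imports it needs are already present in the listing above):

    // Hypothetical sketch, not the code from the referenced article.
    // ImageFolder treats each sub-folder name under the root path as a class label.
    public static ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
        ImageFolder dataset = ImageFolder.builder()
                .setRepositoryPath(Paths.get(datasetRoot))   // root folder with one sub-folder per class
                .optPipeline(new Pipeline()
                        .add(new Resize(224, 224))           // the NiN network above expects 224x224 inputs
                        .add(new ToTensor()))                // HWC uint8 image -> CHW float32 tensor in [0, 1]
                .setSampling(32, true)                       // batch size 32, shuffled (assumed values)
                .build();
        dataset.prepare();                                   // index the image files and labels
        return dataset;
    }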

You also need to add the following dependency to the pom.xml (it is used to plot the training curves):

        <dependency>
            <groupId>tech.tablesaw</groupId>
            <artifactId>tablesaw-jsplot</artifactId>
            <version>0.30.4</version>
        </dependency>
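
The "// plotting" placeholder in the main method can then be filled in with tablesaw-jsplot. Below is a minimal sketch (an assumption, not code from the original article) that groups the "Data" table built above by its "lossLabel" column and opens the resulting chart in the default browser:

    import tech.tablesaw.plotly.Plot;
    import tech.tablesaw.plotly.api.LinePlot;

    // inside the try block in main, after the Table "data" has been created:
    Plot.show(LinePlot.create("NiN training metrics", data, "epoch", "metrics", "lossLabel"));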
