Java Deep Learning with DJL: Building a DenseNet Network
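
This post walks through building and training a DenseNet-style network from scratch with DJL (Deep Java Library). The custom DenseBlock below extends AbstractBlock and concatenates each convolution block's output with its input along the channel dimension; transition layers between dense blocks halve the channel count and the spatial resolution. The full listing trains the network on Fashion-MNIST (resized to 96x96) and prints the final loss and train/test accuracy.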

package com.example.demo.djl;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.util.PairList;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class DenseBlock extends AbstractBlock {

    private static final byte VERSION = 2;
    public SequentialBlock net = new SequentialBlock();

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        int batchSize = 256;
        float lr = 0.1f;
        int numEpochs = 10;

        double[] trainLoss;
        double[] testAccuracy;
        double[] epochCount;
        double[] trainAccuracy;

        epochCount = new double[numEpochs];

        for (int i = 0; i < epochCount.length; i++) {
            epochCount[i] = (i + 1);
        }
        Loss loss = Loss.softmaxCrossEntropyLoss();

        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        DefaultTrainingConfig config = new DefaultTrainingConfig(loss) // Loss function
                .optOptimizer(sgd) // Optimizer
                .addEvaluator(new Accuracy()) // Model accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        try (NDManager manager = NDManager.newBaseManager();
             Model model = DenseBlock.getModel();
             Trainer trainer = model.newTrainer(config)) {

            // Quick sanity check for a stand-alone DenseBlock (commented out):
            // two conv blocks of 10 channels turn a (4, 3, 8, 8) input into
            // (4, 3 + 2 * 10, 8, 8) = (4, 23, 8, 8).
            //
            // SequentialBlock block = new SequentialBlock().add(new DenseBlock(2, 10));
            // NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 3, 8, 8));
            //
            // block.setInitializer(new XavierInitializer());
            // block.initialize(manager, DataType.FLOAT32, X.getShape());
            //
            // Shape[] currentShape = new Shape[]{X.getShape()};
            // for (int i = 0; i < block.getChildren().size(); i++) {
            //     // feed the previous output shape forward, not the original input shape
            //     currentShape = block.getChildren().get(i).getValue()
            //             .getOutputShapes(manager, currentShape);
            //     System.out.println(currentShape[0]);
            // }

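            // Fashion-MNIST images are 28x28; Resize(96) scales them up to the
            // 96x96 input shape the trainer is initialized with below.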
            FashionMnist trainIter =
                    FashionMnist.builder()
                            .optPipeline(new Pipeline().add(new Resize(96)).add(new ToTensor()))
                            .optUsage(Dataset.Usage.TRAIN)
                            .setSampling(batchSize, true)
                            .build();

            FashionMnist testIter =
                    FashionMnist.builder()
                            .optPipeline(new Pipeline().add(new Resize(96)).add(new ToTensor()))
                            .optUsage(Dataset.Usage.TEST)
                            .setSampling(batchSize, true)
                            .build();

            trainIter.prepare();
            testIter.prepare();

            trainer.initialize(new Shape(1, 1, 96, 96));

            Map<String, double[]> evaluatorMetrics = new HashMap<>();
            // trainingChapter6 returns the mean epoch time; assigning through a
            // double parameter would be lost, since Java passes primitives by value.
            double avgTrainTimePerEpoch =
                    trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);
            trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
            trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
            testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

            System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
            System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
            System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
            System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
            System.out.println();
        }


    }

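    /**
     * A dense block: {@code numConvs} convolution blocks, each producing
     * {@code numChannels} feature maps that are concatenated onto the input.
     */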
    public DenseBlock(int numConvs, int numChannels) {
        super(VERSION);
        for (int i = 0; i < numConvs; i++) {
            this.net.add(
                    addChildBlock("denseBlock" + i, convBlock(numChannels))
            );
        }
    }

    @Override
    public String toString() {
        return "DenseBlock()";
    }

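    /**
     * Runs each child block and concatenates its output with the running input
     * along the channel axis (axis 1), so the channel count grows by
     * numChannels per convolution block.
     */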
    @Override
    public NDList forward(ParameterStore parameterStore, NDList X, boolean training, PairList<String, Object> pairList) {

        NDArray Y;
        for (Block block : this.net.getChildren().values()) {
            Y = block.forward(parameterStore, X, training).singletonOrThrow();
            X = new NDList(NDArrays.concat(new NDList(X.singletonOrThrow(), Y), 1));
        }
        return X;
    }

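    /**
     * Mirrors forward(): after each child block the channel dimension of the
     * running shape grows by that block's output channels.
     */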
    @Override
    public Shape[] getOutputShapes(NDManager ndManager, Shape[] inputs) {
        Shape[] shapesX = inputs;
        for (Block block : this.net.getChildren().values()) {
            Shape[] shapesY = block.getOutputShapes(ndManager, shapesX);
            shapesX[0] = new Shape(
                    shapesX[0].get(0),
                    shapesY[0].get(1) + shapesX[0].get(1),
                    shapesX[0].get(2),
                    shapesX[0].get(3)
            );
        }
        return shapesX;
    }

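    /**
     * Initializes each child block with the accumulated input shape, growing
     * the channel dimension the same way forward() does.
     */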
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape shapesX = inputShapes[0];

        for (Block block : this.net.getChildren().values()) {
            Shape[] shapesY = block.initialize(manager, dataType, shapesX);
            shapesX = new Shape(
                    shapesX.get(0),
                    shapesY[0].get(1) + shapesX.get(1),
                    shapesX.get(2),
                    shapesX.get(3)
            );
        }
    }

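    /**
     * DenseNet's pre-activation convolution unit: BatchNorm -> ReLU -> 3x3
     * convolution with padding 1, which preserves the spatial dimensions.
     */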
    public static SequentialBlock convBlock(int numChannels) {

        SequentialBlock block = new SequentialBlock()
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .optStride(new Shape(1, 1))
                        .build()
                );
        return block;
    }

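    /**
     * A transition layer: a 1x1 convolution reduces the channel count and a
     * stride-2 average pool halves the spatial resolution.
     */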
    public static SequentialBlock transitionBlock(int numChannels) {
        SequentialBlock blk = new SequentialBlock()
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(
                        Conv2d.builder()
                                .setFilters(numChannels)
                                .setKernelShape(new Shape(1, 1))
                                .optStride(new Shape(1, 1))
                                .build()
                )
                .add(Pool.avgPool2dBlock(new Shape(2, 2), new Shape(2, 2)));

        return blk;
    }


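    /**
     * Assembles the full network: a 7x7 convolution stem, four dense blocks
     * with transition layers in between, and a global-average-pool + linear
     * head. Channel bookkeeping with growthRate = 32 and 4 convs per block:
     * 64 -> 192 -> 96 -> 224 -> 112 -> 240 -> 120 -> 248.
     */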
    public static Model getModel() {
        Model model = Model.newInstance("DenseNet");
        SequentialBlock net = new SequentialBlock()
                .add(Conv2d.builder()
                        .setFilters(64)
                        .setKernelShape(new Shape(7, 7))
                        .optStride(new Shape(2, 2))
                        .optPadding(new Shape(3, 3))
                        .build())
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(1, 1)));

        int numChannels = 64;
        int growthRate = 32;

        int[] numConvsInDenseBlocks = new int[]{4, 4, 4, 4};

        for (int index = 0; index < numConvsInDenseBlocks.length; index++) {

            int numConvs = numConvsInDenseBlocks[index];
            net.add(new DenseBlock(numConvs, growthRate));

            numChannels += (numConvs * growthRate);

            if (index != (numConvsInDenseBlocks.length - 1)) {
                numChannels = (numChannels / 2);
                net.add(transitionBlock(numChannels));
            }
        }

        net.add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Pool.globalAvgPool2dBlock())
                // 10 output units: the number of Fashion-MNIST classes
                .add(Linear.builder().setUnits(10).build());
        System.out.println(net);
        model.setBlock(net);
        return model;
    }

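    /**
     * Fits the model with EasyTrain, copies the per-epoch evaluator metrics
     * into {@code evaluatorMetrics}, and returns the mean epoch time in
     * nanoseconds.
     */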
    public static double trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
                                          int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics)
            throws IOException, TranslateException {

        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });

        // Return the mean epoch time in nanoseconds; the caller converts it to examples/sec.
        return metrics.mean("epoch");
    }
}
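
After training, you may want to persist the learned weights. The snippet below is a minimal sketch using DJL's Model.save; the build/densenet output directory is an arbitrary choice, and the call belongs inside the try-with-resources block in main, after trainingChapter6 returns:

import java.nio.file.Paths;

// Writes the trained parameters under build/densenet,
// named after the model plus the current epoch number.
model.save(Paths.get("build/densenet"), "DenseNet");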
