Deep Learning in Java: Creating AlexNet with DJL
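This post builds AlexNet with DJL (Deep Java Library) and trains it on Fashion-MNIST. The listing below defines the network as a SequentialBlock, prints each layer's output shape for a 224x224 input, then trains for 10 epochs with SGD and reports the final loss, accuracy, and throughput. The training loop follows the d2l-java materials (hence the trainingChapter6 helper), and the ai.djl.basicdataset.FashionMnist import suggests an older DJL release (around 0.8.x); you will need the api and basicdataset artifacts plus an engine such as MXNet on the classpath.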

package com.example.demo.djl;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class AlexNetTest {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

        float lr = 0.01f;
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // Plain SGD with a fixed learning rate
        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        DefaultTrainingConfig config = new DefaultTrainingConfig(loss) // Loss function
                .optOptimizer(sgd) // Optimizer
                .addEvaluator(new Accuracy()) // Model accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        try (NDManager manager = NDManager.newBaseManager();
             Model model = AlexNetTest.getModel();
             Trainer trainer = model.newTrainer(config)) {
            SequentialBlock block = (SequentialBlock) model.getBlock();
            // Initialize parameters with a dummy 1 x 1 x 224 x 224 input
            // (batch, channels, height, width), matching the resized Fashion-MNIST images
            NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 224, 224));
            trainer.initialize(X.getShape());

            Shape currentShape = X.getShape();

            for (int i = 0; i < block.getChildren().size(); i++) {
                Shape[] newShape = block.getChildren().get(i).getValue()
                        .getOutputShapes(manager, new Shape[]{currentShape});
                currentShape = newShape[0];
                System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
            }
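            // Expected output for a 1 x 1 x 224 x 224 input, computed from the
            // kernel/stride/padding settings below (ReLU and dropout blocks leave
            // shapes unchanged; exact layer names depend on the DJL version):
            //   conv2d    -> (1, 96, 54, 54)    11x11 kernel, stride 4
            //   maxPool2d -> (1, 96, 26, 26)    3x3 window, stride 2
            //   conv2d    -> (1, 256, 26, 26)   5x5 kernel, padding 2
            //   maxPool2d -> (1, 256, 12, 12)
            //   conv2d    -> (1, 384, 12, 12)   3x3 kernel, padding 1 (twice), then 256 filters
            //   maxPool2d -> (1, 256, 5, 5)
            //   batchFlatten -> (1, 6400), then linear 4096 -> 4096 -> 10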

            // Prepare the data

            int batchSize = 128;
            int numEpochs = 10;
            double[] trainLoss;
            double[] testAccuracy;
            double[] trainAccuracy;

            // Epoch indices 1..numEpochs (handy if you later plot the training curves)
            double[] epochCount = new double[numEpochs];
            for (int i = 0; i < numEpochs; i++) {
                epochCount[i] = i + 1;
            }

            FashionMnist trainIter = FashionMnist.builder()
                    .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
                    .optUsage(Dataset.Usage.TRAIN)
                    .setSampling(batchSize, true)
                    .build();

            FashionMnist testIter = FashionMnist.builder()
                    .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
                    .optUsage(Dataset.Usage.TEST)
                    .setSampling(batchSize, true)
                    .build();

            trainIter.prepare();
            testIter.prepare();
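            // Fashion-MNIST images are 28x28; Resize(224) upsamples them so the
            // 11x11, stride-4 first convolution has enough spatial extent to work with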

            Map<String, double[]> evaluatorMetrics = new HashMap<>();
            // trainingChapter6 returns the average wall-clock time per epoch (nanoseconds);
            // the original passed it as a double parameter, which Java copies by value,
            // so the caller never saw the measured time
            double avgTrainTimePerEpoch = trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);

            trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
            trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
            testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

            System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
            System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
            System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
            System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
            System.out.println();
        }


    }

    public static Model getModel() {
        Model model = Model.newInstance("AlexNet");

        // Here, we use a larger 11 x 11 window to capture objects. At the same
        // time, we use a stride of 4 to greatly reduce the height and width of
        // the output. The number of output channels is much larger than in LeNet.
        SequentialBlock block = new SequentialBlock();
        block
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(11, 11))
                        .optStride(new Shape(4, 4))
                        .setFilters(96).build())
                .add(Activation::relu)
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // Make the convolution window smaller, set padding to 2 for consistent
                // height and width across the input and output, and increase the
                // number of output channels
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(5, 5))
                        .optPadding(new Shape(2, 2))
                        .setFilters(256).build())
                .add(Activation::relu)
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // Use three successive convolutional layers and a smaller convolution
                // window. Except for the final convolutional layer, the number of
                // output channels is further increased. Pooling layers are not used to
                // reduce the height and width of input after the first two
                // convolutional layers
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .setFilters(384).build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .setFilters(384).build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .setFilters(256).build())
                .add(Activation::relu)
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // Here, the number of outputs of the fully connected layer is several
                // times larger than that in LeNet. Use the dropout layer to mitigate
                // overfitting
                .add(Blocks.batchFlattenBlock())
                .add(Linear
                        .builder()
                        .setUnits(4096)
                        .build())
                .add(Activation::relu)
                .add(Dropout
                        .builder()
                        .optRate(0.5f)
                        .build())
                .add(Linear
                        .builder()
                        .setUnits(4096)
                        .build())
                .add(Activation::relu)
                .add(Dropout
                        .builder()
                        .optRate(0.5f)
                        .build())
                // Output layer. Since we are using Fashion-MNIST, the number of
                // classes is 10, instead of 1000 as in the paper
                .add(Linear.builder().setUnits(10).build());
        System.out.println(block);
        model.setBlock(block);
        return model;
    }

    public static double trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
                                          int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics) throws IOException, TranslateException {

        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        Metrics metrics = trainer.getMetrics();

        // EasyTrain records per-epoch metrics under "train_epoch_<name>" and
        // "validate_epoch_<name>"; copy them into plain double[] for reporting
        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });

        // Average epoch time in nanoseconds, as recorded in the trainer's metrics
        return metrics.mean("epoch");
    }
}
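After training, you will typically want to persist the parameters and sanity-check a prediction. The snippet below is a minimal sketch of both steps against the same DJL API level as the listing above; the output path is a placeholder, and running a raw forward pass through a ParameterStore (rather than a Predictor with a Translator) is an illustrative shortcut, not part of the original post. It assumes the model, manager, and numEpochs variables from main() are still in scope, and needs ai.djl.ndarray.NDList, ai.djl.training.ParameterStore, and java.nio.file.* imports.

        // Save the trained parameters (the path "build/alexnet" is hypothetical)
        Path modelDir = Paths.get("build/alexnet");
        Files.createDirectories(modelDir);
        model.setProperty("Epoch", String.valueOf(numEpochs)); // recorded in the params file name
        model.save(modelDir, "alexnet");

        // Sanity check: run one random 224x224 "image" through the network.
        // ParameterStore(manager, false) serves the parameters in inference mode.
        NDArray sample = manager.randomUniform(0f, 1f, new Shape(1, 1, 224, 224));
        NDList output = model.getBlock()
                .forward(new ParameterStore(manager, false), new NDList(sample), false);
        NDArray logits = output.singletonOrThrow(); // shape (1, 10): one logit per class
        System.out.println("predicted class: " + logits.argMax(1).getLong(0));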
