Java 深度学习框架 DJL 之 MNIST 手写数字(0-9)识别

1、快速创建项目的插件Alibaba Cloud Toolkit

2、在 pom.xml 中添加依赖

   <!-- https://mvnrepository.com/artifact/ai.djl/api -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.9.0</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-api -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.26</version>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.26</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/ai.djl.mxnet/mxnet-native-auto -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.7.0-backport</version>
        </dependency>

       <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>0.9.0</version>
        </dependency>


        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.9.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.9.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>0.9.0</version>
        </dependency>

        <dependency>
            <groupId>net.java.dev.jna</groupId>
            <artifactId>jna</artifactId>
            <version>5.3.0</version>
        </dependency>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.4</version>
        </dependency>

3、运行时如果提示缺少各种 DLL 文件,可以下载 DLL 修复工具进行修复

lanzous.com - lanzous 资源和信息。

如果仍然不行,那就重新安装 VC2015(Visual C++ 2015 运行库)

http://c.biancheng.net/view/453.html

4、训练 MNIST 手写数字识别的代码

Arguments.java
package com.example.demo.util;

import ai.djl.Device;
import ai.djl.util.JsonUtils;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/**
 * Holds the training options parsed from the command line.
 *
 * <p>Instances are normally created through {@link #parseArgs(String[])}; every option also has a
 * setter so callers can override the parsed value afterwards.
 */
public class Arguments {

    /** Command-line syntax shown at the top of the generated help text. */
    private static final String USAGE = "./gradlew run --args='[OPTIONS]'";

    private int epoch;
    private int batchSize;
    private int maxGpus;
    private boolean isSymbolic;
    private boolean preTrained;
    private String outputDir;
    private long limit;
    private String modelDir;
    private Map<String, String> criteria;

    /**
     * Builds the arguments from an already parsed command line, applying defaults for every option
     * that is absent.
     *
     * <ul>
     *   <li>{@code epoch}: defaults to 2
     *   <li>{@code max-gpus}: never larger than {@link Device#getGpuCount()}
     *   <li>{@code batch-size}: defaults to {@code 32 * maxGpus}, or 32 when no GPU is used
     *   <li>{@code output-dir}: defaults to {@code build/model}
     *   <li>{@code max-batches}: converted to a sample limit ({@code max-batches * batchSize});
     *       defaults to {@link Long#MAX_VALUE}
     *   <li>{@code model-dir} and {@code criteria}: default to {@code null}
     * </ul>
     *
     * @param cmd the parsed command line
     * @throws NumberFormatException if a numeric option value is not a valid number
     */
    public Arguments(CommandLine cmd) {
        if (cmd.hasOption("epoch")) {
            epoch = Integer.parseInt(cmd.getOptionValue("epoch"));
        } else {
            epoch = 2;
        }
        maxGpus = Device.getGpuCount();
        if (cmd.hasOption("max-gpus")) {
            maxGpus = Math.min(Integer.parseInt(cmd.getOptionValue("max-gpus")), maxGpus);
        }
        if (cmd.hasOption("batch-size")) {
            batchSize = Integer.parseInt(cmd.getOptionValue("batch-size"));
        } else {
            batchSize = maxGpus > 0 ? 32 * maxGpus : 32;
        }
        isSymbolic = cmd.hasOption("symbolic-model");
        preTrained = cmd.hasOption("pre-trained");

        if (cmd.hasOption("output-dir")) {
            outputDir = cmd.getOptionValue("output-dir");
        } else {
            outputDir = "build/model";
        }
        if (cmd.hasOption("max-batches")) {
            // The limit is expressed in samples, hence the multiplication by the batch size.
            limit = Long.parseLong(cmd.getOptionValue("max-batches")) * batchSize;
        } else {
            limit = Long.MAX_VALUE;
        }
        if (cmd.hasOption("model-dir")) {
            modelDir = cmd.getOptionValue("model-dir");
        } else {
            modelDir = null;
        }
        if (cmd.hasOption("criteria")) {
            // Parse with the same type as the field: a Map<String, Object> token would let Gson
            // store Double/Boolean values in a map that callers read as Map<String, String>.
            Type type = new TypeToken<Map<String, String>>() {}.getType();
            criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type);
        }
    }

    /**
     * Parses the raw program arguments.
     *
     * @param args the program arguments
     * @return the parsed arguments, or {@code null} when help was requested or the arguments could
     *     not be parsed (the help text is printed in both cases)
     */
    public static Arguments parseArgs(String[] args) {
        Options options = Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            if (cmd.hasOption("help")) {
                printHelp(USAGE, options);
                return null;
            }
            return new Arguments(cmd);
        } catch (ParseException e) {
            // Tell the user what was wrong before showing the usage text.
            System.err.println(e.getMessage());
            printHelp(USAGE, options);
        }
        return null;
    }

    /**
     * Declares every command-line option understood by {@link #parseArgs(String[])}.
     *
     * @return the option definitions
     */
    public static Options getOptions() {
        Options options = new Options();
        options.addOption(
                Option.builder("h").longOpt("help").hasArg(false).desc("Print this help.").build());
        options.addOption(
                Option.builder("e")
                        .longOpt("epoch")
                        .hasArg()
                        .argName("EPOCH")
                        .desc("Numbers of epochs user would like to run")
                        .build());
        options.addOption(
                Option.builder("b")
                        .longOpt("batch-size")
                        .hasArg()
                        .argName("BATCH-SIZE")
                        .desc("The batch size of the training data.")
                        .build());
        options.addOption(
                Option.builder("g")
                        .longOpt("max-gpus")
                        .hasArg()
                        .argName("MAXGPUS")
                        .desc("Max number of GPUs to use for training")
                        .build());
        options.addOption(
                Option.builder("s")
                        .longOpt("symbolic-model")
                        .argName("SYMBOLIC")
                        .desc("Use symbolic model, use imperative model if false")
                        .build());
        options.addOption(
                Option.builder("p")
                        .longOpt("pre-trained")
                        .argName("PRE-TRAINED")
                        .desc("Use pre-trained weights")
                        .build());
        options.addOption(
                Option.builder("o")
                        .longOpt("output-dir")
                        .hasArg()
                        .argName("OUTPUT-DIR")
                        .desc("Use output to determine directory to save your model parameters")
                        .build());
        options.addOption(
                Option.builder("m")
                        .longOpt("max-batches")
                        .hasArg()
                        .argName("max-batches")
                        .desc(
                                "Limit each epoch to a fixed number of iterations to test the training script")
                        .build());
        options.addOption(
                Option.builder("d")
                        .longOpt("model-dir")
                        .hasArg()
                        .argName("MODEL-DIR")
                        .desc("pre-trained model file directory")
                        .build());
        options.addOption(
                Option.builder("r")
                        .longOpt("criteria")
                        .hasArg()
                        .argName("CRITERIA")
                        .desc("The criteria used for the model.")
                        .build());
        return options;
    }

    /** @return the batch size of the training data */
    public int getBatchSize() {
        return batchSize;
    }

    /** @return the number of epochs to run */
    public int getEpoch() {
        return epoch;
    }

    /** @return the maximum number of GPUs to use */
    public int getMaxGpus() {
        return maxGpus;
    }

    /** @return {@code true} if the {@code symbolic-model} option was given */
    public boolean isSymbolic() {
        return isSymbolic;
    }

    /** @return {@code true} if the {@code pre-trained} option was given */
    public boolean isPreTrained() {
        return preTrained;
    }

    /** @return the pre-trained model directory, or {@code null} if none was given */
    public String getModelDir() {
        return modelDir;
    }

    /** @return the directory the model parameters are saved to */
    public String getOutputDir() {
        return outputDir;
    }

    /** @return the sample limit, {@link Long#MAX_VALUE} when unlimited */
    public long getLimit() {
        return limit;
    }

    /** @return the model criteria, or {@code null} if none was given */
    public Map<String, String> getCriteria() {
        return criteria;
    }

    /**
     * Prints the usage text for the given options.
     *
     * @param msg the command-line syntax line
     * @param options the options to describe
     */
    private static void printHelp(String msg, Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.setLeftPadding(1);
        formatter.setWidth(120);
        formatter.printHelp(msg, options);
    }

    /** @param epoch the number of epochs to run */
    public void setEpoch(int epoch) {
        this.epoch = epoch;
    }

    /** @param batchSize the batch size of the training data */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** @param maxGpus the maximum number of GPUs to use */
    public void setMaxGpus(int maxGpus) {
        this.maxGpus = maxGpus;
    }

    /** @param symbolic whether to use the symbolic model */
    public void setSymbolic(boolean symbolic) {
        isSymbolic = symbolic;
    }

    /** @param preTrained whether to use pre-trained weights */
    public void setPreTrained(boolean preTrained) {
        this.preTrained = preTrained;
    }

    /** @param outputDir the directory the model parameters are saved to */
    public void setOutputDir(String outputDir) {
        this.outputDir = outputDir;
    }

    /** @param limit the sample limit */
    public void setLimit(long limit) {
        this.limit = limit;
    }

    /** @param modelDir the pre-trained model directory */
    public void setModelDir(String modelDir) {
        this.modelDir = modelDir;
    }

    /** @param criteria the model criteria */
    public void setCriteria(Map<String, String> criteria) {
        this.criteria = criteria;
    }
}
Mlp.java
package com.example.demo.djl;

import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import java.util.function.Function;

/**
 * Multilayer Perceptron (MLP) neural network.
 *
 * <p>An MLP is a feedforward artificial neural network that maps a set of inputs to a set of
 * outputs through one or more hidden layers connected as a directed graph. It is trained with
 * backpropagation.
 *
 * <p>MLPs are widely used for supervised learning problems and in research on computational
 * neuroscience and parallel distributed processing. Applications include speech recognition, image
 * recognition and machine translation.
 */
public class Mlp extends SequentialBlock {

    /**
     * Creates an MLP that uses RELU as its activation function.
     *
     * @param input the size of the input vector
     * @param output the size of the output vector
     * @param hidden the sizes of all of the hidden layers
     */
    public Mlp(int input, int output, int[] hidden) {
        this(input, output, hidden, Activation::relu);
    }

    /**
     * Creates an MLP with a caller-supplied activation function.
     *
     * <p>The network is assembled in order: a batch-flatten block, then one linear layer followed
     * by the activation for every entry of {@code hidden}, and finally a linear output layer.
     *
     * @param input the size of the input vector
     * @param output the size of the output vector
     * @param hidden the sizes of all of the hidden layers
     * @param activation the activation function to use
     */
    public Mlp(int input, int output, int[] hidden, Function<NDList, NDList> activation) {
        add(Blocks.batchFlattenBlock(input));

        for (int layer = 0; layer < hidden.length; layer++) {
            Linear dense = Linear.builder().setUnits(hidden[layer]).build();
            add(dense);
            add(activation);
        }

        Linear outputLayer = Linear.builder().setUnits(output).build();
        add(outputLayer);
    }
}

运行代码

TrainMnist.java
package com.example.demo.djl;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.CheckpointsTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.example.demo.util.Arguments;

import java.io.IOException;

/**
 * An example of training an image classification (MNIST) model that recognises the handwritten
 * digits 0-9.
 *
 * <p>See this <a
 * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_mnist_mlp.md">doc</a> for
 * information about this example.
 */
public final class TrainMnist {

    private TrainMnist() {
    }

    /**
     * Program entry point: sets the DJL cache directory and runs the training example.
     *
     * @param args the command-line arguments, see {@link Arguments#getOptions()}
     * @throws IOException if the dataset cannot be prepared
     * @throws TranslateException if training fails while processing the data
     */
    public static void main(String[] args) throws IOException, TranslateException {
        // NOTE(review): hard-coded Windows path; adjust it for the local machine.
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        TrainMnist.runExample(args);
    }

    /**
     * Trains an MLP on the MNIST dataset.
     *
     * @param args the command-line arguments
     * @return the training result, or {@code null} when the arguments requested help or could not
     *     be parsed
     * @throws IOException if the dataset cannot be prepared
     * @throws TranslateException if training fails while processing the data
     */
    public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
        Arguments arguments = Arguments.parseArgs(args);
        // parseArgs returns null after printing the help text (--help or a parse error);
        // without this guard the setters below would throw a NullPointerException.
        if (arguments == null) {
            return null;
        }
        // Fixed settings for this tutorial; they override whatever was passed on the command line.
        arguments.setEpoch(5);
        arguments.setBatchSize(64);
        arguments.setMaxGpus(1);

        // Construct the neural network for the 28 * 28 handwritten digit images:
        //   input layer size:   28 * 28
        //   output layer size:  10 (the digits 0 1 2 3 4 5 6 7 8 9)
        //   hidden layer sizes: {128, 64}
        Block block =
                new Mlp(
                        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
                        Mnist.NUM_CLASSES,
                        new int[]{128, 64});

        try (Model model = Model.newInstance("mlp")) {
            model.setBlock(block);

            // get training and validation dataset
            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

            // setup training configuration
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                /*
                 * MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray.
                 * 1st axis is batch axis, we can use 1 for initialization.
                 */
                Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);

                EasyTrain.fit(trainer, arguments.getEpoch(), trainingSet, validateSet);

                return trainer.getTrainingResult();
            }
        }
    }

    /**
     * Creates the training configuration: softmax cross-entropy loss, an accuracy evaluator, the
     * devices selected by {@code arguments.getMaxGpus()}, the default logging listeners and a
     * checkpoint listener that records the validation accuracy and loss as model properties.
     *
     * @param arguments the training arguments
     * @return the training configuration
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        String outputDir = arguments.getOutputDir();
        CheckpointsTrainingListener listener = new CheckpointsTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Device.getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }

    /**
     * Builds and prepares the MNIST dataset for the given usage, sampled randomly with the
     * configured batch size and capped at the configured limit.
     *
     * @param usage which part of the dataset to load (training or test)
     * @param arguments the training arguments
     * @return the prepared dataset
     * @throws IOException if the dataset cannot be prepared
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optLimit(arguments.getLimit())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非ban必选

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值