1、快速创建项目的插件Alibaba Cloud Toolkit
2、pom添加依赖
<!-- https://mvnrepository.com/artifact/ai.djl/api -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.9.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-api -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.26</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.26</version>
</dependency>
<!-- https://mvnrepository.com/artifact/ai.djl.mxnet/mxnet-native-auto -->
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-native-auto</artifactId>
<version>1.7.0-backport</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-model-zoo</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-engine</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<version>5.3.0</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.4</version>
</dependency>
3、由于缺少各种dll,下载工具进行修复
如果不行 那就重新安装 VC2015
http://c.biancheng.net/view/453.html
4、训练 MNIST 手写数字识别 代码
Arguments.java
package com.example.demo.util;
import ai.djl.Device;
import ai.djl.util.JsonUtils;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
/**
 * Holder for the command-line options of the training examples.
 *
 * <p>Parses epoch count, batch size, GPU limit, model flags, output/model directories,
 * a per-epoch batch limit and an optional JSON criteria map from an Apache Commons CLI
 * {@link CommandLine}.
 */
public class Arguments {

    private int epoch;
    private int batchSize;
    private int maxGpus;
    private boolean isSymbolic;
    private boolean preTrained;
    private String outputDir;
    private long limit;
    private String modelDir;
    private Map<String, String> criteria;

    /**
     * Populates the arguments from a parsed command line.
     *
     * @param cmd a command line parsed against {@link #getOptions()}
     */
    public Arguments(CommandLine cmd) {
        if (cmd.hasOption("epoch")) {
            epoch = Integer.parseInt(cmd.getOptionValue("epoch"));
        } else {
            epoch = 2;
        }
        // Never request more GPUs than the machine actually reports.
        maxGpus = Device.getGpuCount();
        if (cmd.hasOption("max-gpus")) {
            maxGpus = Math.min(Integer.parseInt(cmd.getOptionValue("max-gpus")), maxGpus);
        }
        if (cmd.hasOption("batch-size")) {
            batchSize = Integer.parseInt(cmd.getOptionValue("batch-size"));
        } else {
            // Default batch size scales with the number of GPUs in use (32 per device).
            batchSize = maxGpus > 0 ? 32 * maxGpus : 32;
        }
        isSymbolic = cmd.hasOption("symbolic-model");
        preTrained = cmd.hasOption("pre-trained");
        if (cmd.hasOption("output-dir")) {
            outputDir = cmd.getOptionValue("output-dir");
        } else {
            outputDir = "build/model";
        }
        if (cmd.hasOption("max-batches")) {
            // "max-batches" is expressed in batches; convert to a sample-count limit.
            limit = Long.parseLong(cmd.getOptionValue("max-batches")) * batchSize;
        } else {
            limit = Long.MAX_VALUE;
        }
        if (cmd.hasOption("model-dir")) {
            modelDir = cmd.getOptionValue("model-dir");
        } else {
            modelDir = null;
        }
        if (cmd.hasOption("criteria")) {
            // BUG FIX: the original parsed with TypeToken<Map<String, Object>>, which
            // silently stores non-String values into this Map<String, String> field
            // (heap pollution) and can throw ClassCastException far from here.
            // Parse as the field's declared type instead.
            Type type = new TypeToken<Map<String, String>>() {}.getType();
            criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type);
        }
    }

    /**
     * Parses the raw program arguments.
     *
     * @param args the raw command-line arguments
     * @return the parsed {@code Arguments}, or {@code null} if {@code --help} was
     *     requested or parsing failed (usage is printed in both cases)
     */
    public static Arguments parseArgs(String[] args) {
        Options options = Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            if (cmd.hasOption("help")) {
                printHelp("./gradlew run --args='[OPTIONS]'", options);
                return null;
            }
            return new Arguments(cmd);
        } catch (ParseException e) {
            printHelp("./gradlew run --args='[OPTIONS]'", options);
        }
        return null;
    }

    /**
     * Builds the supported command-line options.
     *
     * @return the Commons CLI {@link Options} describing every flag this class reads
     */
    public static Options getOptions() {
        Options options = new Options();
        options.addOption(
                Option.builder("h").longOpt("help").hasArg(false).desc("Print this help.").build());
        options.addOption(
                Option.builder("e")
                        .longOpt("epoch")
                        .hasArg()
                        .argName("EPOCH")
                        .desc("Numbers of epochs user would like to run")
                        .build());
        options.addOption(
                Option.builder("b")
                        .longOpt("batch-size")
                        .hasArg()
                        .argName("BATCH-SIZE")
                        .desc("The batch size of the training data.")
                        .build());
        options.addOption(
                Option.builder("g")
                        .longOpt("max-gpus")
                        .hasArg()
                        .argName("MAXGPUS")
                        .desc("Max number of GPUs to use for training")
                        .build());
        options.addOption(
                Option.builder("s")
                        .longOpt("symbolic-model")
                        .argName("SYMBOLIC")
                        .desc("Use symbolic model, use imperative model if false")
                        .build());
        options.addOption(
                Option.builder("p")
                        .longOpt("pre-trained")
                        .argName("PRE-TRAINED")
                        .desc("Use pre-trained weights")
                        .build());
        options.addOption(
                Option.builder("o")
                        .longOpt("output-dir")
                        .hasArg()
                        .argName("OUTPUT-DIR")
                        .desc("Use output to determine directory to save your model parameters")
                        .build());
        options.addOption(
                Option.builder("m")
                        .longOpt("max-batches")
                        .hasArg()
                        .argName("max-batches")
                        .desc(
                                "Limit each epoch to a fixed number of iterations to test the training script")
                        .build());
        options.addOption(
                Option.builder("d")
                        .longOpt("model-dir")
                        .hasArg()
                        .argName("MODEL-DIR")
                        .desc("pre-trained model file directory")
                        .build());
        options.addOption(
                Option.builder("r")
                        .longOpt("criteria")
                        .hasArg()
                        .argName("CRITERIA")
                        .desc("The criteria used for the model.")
                        .build());
        return options;
    }

    public int getBatchSize() {
        return batchSize;
    }

    public int getEpoch() {
        return epoch;
    }

    public int getMaxGpus() {
        return maxGpus;
    }

    public boolean isSymbolic() {
        return isSymbolic;
    }

    public boolean isPreTrained() {
        return preTrained;
    }

    public String getModelDir() {
        return modelDir;
    }

    public String getOutputDir() {
        return outputDir;
    }

    public long getLimit() {
        return limit;
    }

    public Map<String, String> getCriteria() {
        return criteria;
    }

    /** Prints a formatted usage message for the given options. */
    private static void printHelp(String msg, Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.setLeftPadding(1);
        formatter.setWidth(120);
        formatter.printHelp(msg, options);
    }

    public void setEpoch(int epoch) {
        this.epoch = epoch;
    }

    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public void setMaxGpus(int maxGpus) {
        this.maxGpus = maxGpus;
    }

    public void setSymbolic(boolean symbolic) {
        isSymbolic = symbolic;
    }

    public void setPreTrained(boolean preTrained) {
        this.preTrained = preTrained;
    }

    public void setOutputDir(String outputDir) {
        this.outputDir = outputDir;
    }

    public void setLimit(long limit) {
        this.limit = limit;
    }

    public void setModelDir(String modelDir) {
        this.modelDir = modelDir;
    }

    public void setCriteria(Map<String, String> criteria) {
        this.criteria = criteria;
    }
}
Mlp.java
package com.example.demo.djl;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import java.util.function.Function;
/**
 * A simple Multilayer Perceptron (MLP) block.
 *
 * <p>An MLP is a feedforward artificial neural network: the input is flattened,
 * passed through a configurable stack of fully-connected hidden layers (each
 * followed by an activation function), and finally projected onto the output
 * units. Backpropagation is used to train the network.
 *
 * <p>MLPs are a common baseline for supervised-learning tasks such as image and
 * speech recognition.
 */
public class Mlp extends SequentialBlock {

    /**
     * Creates an MLP that uses the ReLU activation between layers.
     *
     * @param input the size of the input vector
     * @param output the size of the output vector
     * @param hidden the sizes of all of the hidden layers
     */
    public Mlp(int input, int output, int[] hidden) {
        this(input, output, hidden, Activation::relu);
    }

    /**
     * Creates an MLP with a caller-supplied activation function.
     *
     * @param input the size of the input vector
     * @param output the size of the output vector
     * @param hidden the sizes of all of the hidden layers
     * @param activation the activation function applied after each hidden layer
     */
    public Mlp(int input, int output, int[] hidden, Function<NDList, NDList> activation) {
        // Flatten each sample into a vector of `input` elements.
        add(Blocks.batchFlattenBlock(input));
        // One fully-connected layer plus activation per requested hidden size.
        for (int layer = 0; layer < hidden.length; layer++) {
            add(Linear.builder().setUnits(hidden[layer]).build());
            add(activation);
        }
        // Final projection onto the output units; no activation is applied here.
        add(Linear.builder().setUnits(output).build());
    }
}
运行代码
TrainMnist.java
package com.example.demo.djl;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.CheckpointsTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.example.demo.util.Arguments;
import java.io.IOException;
/**
 * An example of training an image classification (MNIST) model.
 *
 * <p>See this <a
 * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_mnist_mlp.md">doc</a> for
 * information about this example.
 * Trains an MLP to recognize the handwritten digits 0-9.
 */
public final class TrainMnist {

    /** Utility class; not instantiable. */
    private TrainMnist() {
    }

    /**
     * Entry point: trains the MNIST MLP with the given command-line arguments.
     *
     * @param args command-line arguments (see {@link Arguments#getOptions()})
     * @throws IOException if the dataset cannot be downloaded or read
     * @throws TranslateException if data pre-processing fails during training
     */
    public static void main(String[] args) throws IOException, TranslateException {
        // Redirect DJL's model/dataset cache to a fixed local directory.
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        TrainMnist.runExample(args);
    }

    /**
     * Runs the full training loop.
     *
     * @param args command-line arguments
     * @return the training result, or {@code null} if argument parsing failed or
     *     {@code --help} was requested
     * @throws IOException if the dataset cannot be prepared
     * @throws TranslateException if data pre-processing fails during training
     */
    public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
        Arguments arguments = Arguments.parseArgs(args);
        // BUG FIX: parseArgs returns null on --help or a parse error; the original
        // commented this guard out and would NPE on arguments.setEpoch below.
        if (arguments == null) {
            return null;
        }
        // Tutorial overrides: pin the run configuration regardless of CLI input.
        arguments.setEpoch(5);
        arguments.setBatchSize(64);
        arguments.setMaxGpus(1);
        // Network layout: 28*28 input pixels, 10 output classes (digits 0-9),
        // two hidden layers of 128 and 64 units.
        Block block =
                new Mlp(
                        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
                        Mnist.NUM_CLASSES,
                        new int[] {128, 64});
        try (Model model = Model.newInstance("mlp")) {
            model.setBlock(block);
            // Training and validation splits of MNIST.
            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);
            DefaultTrainingConfig config = setupTrainingConfig(arguments);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                /*
                 * MNIST is 28x28 grayscale, pre-processed into a flat 28*28 NDArray.
                 * The first axis is the batch axis; 1 is enough for initialization.
                 */
                Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
                // Initialize the parameters before the first forward pass.
                trainer.initialize(inputShape);
                EasyTrain.fit(trainer, arguments.getEpoch(), trainingSet, validateSet);
                return trainer.getTrainingResult();
            }
        }
    }

    /**
     * Builds the training configuration: softmax cross-entropy loss, accuracy
     * evaluator, device selection, logging, and a checkpoint listener that stamps
     * the saved model with its validation accuracy and loss.
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        String outputDir = arguments.getOutputDir();
        CheckpointsTrainingListener listener = new CheckpointsTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    // Record validation metrics as model properties at save time.
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Device.getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }

    /**
     * Loads (downloading if necessary) the requested MNIST split.
     *
     * @param usage which split to load (TRAIN or TEST)
     * @param arguments supplies the batch size and the optional sample limit
     * @return the prepared dataset
     * @throws IOException if the dataset cannot be downloaded or read
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optLimit(arguments.getLimit())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}