步骤1:加载手写数字图像
// Step 1: load the handwritten-digit image.
// A failed download is fatal: without an image there is nothing to classify, and continuing
// with a null reference would only end in a NullPointerException further down.
Image img;
try {
    img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
} catch (IOException e) {
    throw new IllegalStateException("Failed to load image", e);
}
步骤2:加载模型
// Step 2: load the model.
// The Mlp block (28*28 inputs, hidden layers of 128 and 64 units, 10 outputs) presumably has to
// match the network whose parameters are stored in modelDir — confirm against the training code.
// NOTE(review): "D:\\models" is a Windows-only, machine-specific path; adjust it for your setup.
// NOTE(review): Model is AutoCloseable and holds native resources; close it once inference is done.
Path modelDir = Paths.get("D:\\models");
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
try {
    model.load(modelDir);
} catch (IOException | MalformedModelException e) {
    // Continuing without loaded parameters would only fail later; release the model and stop here.
    model.close();
    throw new IllegalStateException("Failed to load model from " + modelDir, e);
}
步骤3:创建翻译器
// Step 3: create the translator.
// Pre-processing turns an Image into the model's NDList input; post-processing turns the
// model's NDList output into labelled Classifications.
Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        // Read the picture as a single grayscale channel, then convert it into a float tensor.
        NDArray grayscale = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
        NDArray tensor = NDImageUtils.toTensor(grayscale);
        return new NDList(tensor);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // Apply softmax over axis 0 of the single output array, and label the entries "0".."9".
        NDArray scores = list.singletonOrThrow();
        NDArray probabilities = scores.softmax(0);
        List<String> labels = IntStream.range(0, 10)
                .mapToObj(String::valueOf)
                .collect(Collectors.toList());
        return new Classifications(labels, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        // Inputs are combined into a batch by stacking.
        return Batchifier.STACK;
    }
};
步骤4:创建预测器
/**
 * Step 4: create the predictor from the model and the translator.
 * The predictor is not thread-safe; do not share one instance across threads.
 * NOTE(review): DJL's Predictor (like Model) is AutoCloseable and is never closed in this
 * walkthrough; close both after inference, e.g. with try-with-resources.
 */
Predictor<Image, Classifications> predictor = model.newPredictor(translator);
步骤5:运行推理
// Step 5: run inference and print the classification as JSON.
// If prediction fails there is no result to print, so rethrow (keeping the cause)
// instead of continuing with a null result and failing with a NullPointerException.
Classifications classifications;
try {
    classifications = predictor.predict(img);
} catch (TranslateException e) {
    throw new IllegalStateException("Inference failed", e);
}
System.out.println(classifications.toJson());
整合后的完整代码
package com.lihao;
import java.awt.image.*;
import java.io.IOException;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.translate.*;
/**
 * Model inference example: classifies a handwritten-digit image with a locally saved DJL MLP.
 *
 * <p>Dependencies (Jupyter-style {@code %maven} coordinates):
 * %maven ai.djl:api:0.23.0
 * %maven ai.djl:model-zoo:0.23.0
 * %maven ai.djl.mxnet:mxnet-engine:0.23.0
 * %maven ai.djl.mxnet:mxnet-model-zoo:0.23.0
 */
public class DjlInferenceModel {

    /** Sample image that {@code main} classifies. */
    private static final String IMAGE_URL = "https://resources.djl.ai/images/0.png";

    /** Model directory used when none is passed on the command line (Windows-specific default). */
    private static final String DEFAULT_MODEL_DIR = "D:\\models";

    /** Model name; presumably must match the name the parameters were saved under — confirm. */
    private static final String MODEL_NAME = "mlp";

    /** Side length in pixels of the input image; the MLP takes IMAGE_SIZE * IMAGE_SIZE inputs. */
    private static final int IMAGE_SIZE = 28;

    /** Hidden-layer widths; must match the architecture the saved parameters belong to. */
    private static final int[] HIDDEN_UNITS = {128, 64};

    /** Output labels "0".."9", one per network output. */
    private static final List<String> CLASS_NAMES = Collections.unmodifiableList(
            IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList()));

    /**
     * Downloads the sample image, classifies it, and prints the classification as JSON.
     *
     * @param args optional; {@code args[0]} overrides the model directory (default "D:\models")
     * @throws IllegalStateException if loading the image or model, or running inference, fails;
     *     the original exception is attached as the cause
     */
    public static void main(String[] args) {
        Path modelDir = Paths.get(args != null && args.length > 0 ? args[0] : DEFAULT_MODEL_DIR);

        // Step 1: load the handwritten-digit image. Fail fast: without it there is nothing to classify.
        Image img;
        try {
            img = ImageFactory.getInstance().fromUrl(IMAGE_URL);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to load image " + IMAGE_URL, e);
        }

        // Steps 2-5: load the model, build the translator and predictor, and run inference.
        Classifications classifications = classify(img, modelDir);
        System.out.println(classifications.toJson());
    }

    /**
     * Loads the MLP parameters from {@code modelDir} and runs one prediction on {@code img}.
     * The model and predictor are closed before this method returns.
     *
     * @param img the image to classify
     * @param modelDir directory containing the saved parameters
     * @return the classification result
     * @throws IllegalStateException if the model cannot be loaded or inference fails
     */
    private static Classifications classify(Image img, Path modelDir) {
        try (Model model = Model.newInstance(MODEL_NAME)) {
            model.setBlock(new Mlp(IMAGE_SIZE * IMAGE_SIZE, CLASS_NAMES.size(), HIDDEN_UNITS));
            model.load(modelDir);
            // The predictor is not thread-safe; here it is used once and closed immediately.
            try (Predictor<Image, Classifications> predictor = model.newPredictor(newTranslator())) {
                return predictor.predict(img);
            }
        } catch (IOException | MalformedModelException e) {
            throw new IllegalStateException("Failed to load model from " + modelDir, e);
        } catch (TranslateException e) {
            throw new IllegalStateException("Inference failed", e);
        }
    }

    /**
     * Creates the translator between {@link Image} inputs and {@link Classifications} outputs.
     *
     * @return a new translator instance
     */
    private static Translator<Image, Classifications> newTranslator() {
        return new Translator<Image, Classifications>() {
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Decode as a single grayscale channel, then convert it into a float tensor.
                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
                return new NDList(NDImageUtils.toTensor(array));
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // Softmax over axis 0 of the single output array; axis 0 is assumed to be the
                // class axis once the batchifier has split the batch — TODO confirm.
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                return new Classifications(CLASS_NAMES, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                // Inputs are combined into a batch by stacking.
                return Batchifier.STACK;
            }
        };
    }
}
POM文件
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.lihao</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
    <name>Spring Boot Blank Project (from https://github.com/making/spring-boot-blank)</name>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.12</version>
    </parent>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- NOTE(review): main class launched by the packaged jar. The DJL example
             com.lihao.DjlInferenceModel has its own main method; point this at it if the jar
             should run the inference example instead of com.lihao.App. -->
        <start-class>com.lihao.App</start-class>
        <java.version>1.8</java.version>
        <!-- Keep every DJL artifact on the same release. -->
        <djl.version>0.23.0</djl.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <!-- Provides ai.djl.basicmodelzoo (the Mlp block used by the example). -->
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>
    </dependencies>
    <build>
        <finalName>djl</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
            <plugin>
                <!-- Version is managed by spring-boot-starter-parent so it stays in step with
                     the parent (a hard-coded 2.6.0 mismatched the 2.7.12 parent). -->
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>
运行结果