I think GPT's command of language is better than that of most bloggers (myself included). Its writing is easier to follow, and whenever something is unclear you can just ask it directly: tireless, conscientious, available around the clock, endlessly patient. With a teacher this good, not studying would be a waste.
So I am planting a flag: a challenge to study with GPT for 365 days. Each day I will write up my study notes and the thread of what I learned (most of the text is generated directly by GPT, which is surely better than anything I would write myself). Thank you, GPT!
Full series: Learning with GPT - AI series
The ResNet model trained on the CIFAR dataset is the ONNX model provided by this repository:
https://github.com/Lucas110550/CIFAR100_TinyImageNet_ResNet
I recently picked up a short-notice task: integrate an ONNX model into a Java service and run inference on it. The integration was painful from start to finish. The Java DJL library's source code contains an infinite-loop bug, and a number of operators are unsupported. After plenty of rework the whole pipeline finally ran end to end, so I am sharing it here for anyone who needs it.
ImageClassificationController in detail
Overview
ImageClassificationController is a Spring Boot REST controller for image-classification tasks. It uses a pretrained ONNX model, loading the model, processing images, and running predictions through DJL (Deep Java Library). The controller exposes two main endpoints: one classifies a single image, the other measures the model's accuracy over an entire test dataset.
Main features
- Classify a single image: an endpoint that accepts an image path and returns its predicted class.
- Test all images: an endpoint that evaluates the model's accuracy over the full test dataset.
Code structure
1. The classify method
Handles single-image classification via a GET request to /classify.
- Steps:
  - Load the synset file, which contains all of the model's class labels.
  - Create the CIFAR100Translator, which handles the model's input and output.
  - Build the model-loading criteria: model path, translator, and a progress bar.
  - Load the model and create a predictor.
  - Read the image file and convert it into an Image object.
  - Run the predictor and return the classification result.
2. The testAll method
Evaluates the whole dataset via a GET request to /testAll.
- Steps:
  - Load the synset file.
  - Create the CIFAR100Translator.
  - Build the model-loading criteria.
  - Load the model and create a predictor.
  - Walk the test dataset directory, visiting each class folder and the image files inside it.
  - Predict each image and tally how many predictions are correct.
  - Compute and return the model's accuracy on the test set.
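The tallying in the final steps reduces to a simple ratio of correct predictions to total predictions. As a quick Python sketch (the labels here are made up for illustration):

```python
def accuracy(true_labels, predicted_labels):
    """Fraction of predictions matching the true label."""
    correct = sum(1 for t, p in zip(true_labels, predicted_labels) if t == p)
    return correct / len(true_labels)

# Made-up labels, just to illustrate the bookkeeping
print(accuracy(["apple", "bear", "rose"], ["apple", "bear", "tulip"]))  # 2 of 3 correct
```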
The CIFAR100Translator class
Handles image input and classification output.
- processInput method:
  - Converts the Image object into an NDArray.
  - Normalizes and rearranges the data by hand, because with the OnnxRuntime engine the NDArray cannot call methods such as div, exp, and softmax directly.
  - Uses the CIFAR-100 dataset's mean and standard deviation for normalization.
  - Creates the normalized NDArray and returns it.
- processOutput method:
  - Takes the NDArray output by the model.
  - Computes each class's probability by hand (maximum, exponentials, normalization), since exp and softmax are likewise unavailable on the NDArray.
  - Converts the output into a classification result of class names and probabilities.
Detailed walkthrough
1. Loading the synset file
The loadSynset method reads the file of class labels and stores them in a list.
private List<String> loadSynset(String synsetPath) throws IOException {
    List<String> synset = new ArrayList<>();
    try (BufferedReader br = new BufferedReader(new FileReader(synsetPath))) {
        String line;
        while ((line = br.readLine()) != null) {
            synset.add(line);
        }
    }
    return synset;
}
2. Manual normalization and rearrangement
In processInput, normalization is done by hand because the NDArray cannot call the relevant methods directly.
// Normalize manually and rearrange the data from HWC to CHW layout
int height = 32;
int width = 32;
int channels = 3;
byte[] rawData = array.toByteArray();
float[] data = new float[height * width * channels];
float[] mean = {0.5071f, 0.4865f, 0.4409f};
float[] std = {0.2673f, 0.2564f, 0.2761f};
// Normalization loop
for (int c = 0; c < channels; c++) {
    for (int h = 0; h < height; h++) {
        for (int w = 0; w < width; w++) {
            int index = c * height * width + h * width + w;
            int rawIndex = (h * width + w) * channels + c;
            data[index] = ((rawData[rawIndex] & 0xFF) / 255.0f - mean[c]) / std[c];
        }
    }
}
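To sanity-check the loop, the same transform can be reproduced with NumPy. The snippet below runs a vectorized version on a random stand-in image (not real CIFAR data) and compares it element by element against the index arithmetic the Java loop uses:

```python
import numpy as np

height, width, channels = 32, 32, 3
mean = np.array([0.5071, 0.4865, 0.4409], dtype=np.float32)
std = np.array([0.2673, 0.2564, 0.2761], dtype=np.float32)

# Random HWC uint8 image standing in for the decoded picture
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(height, width, channels), dtype=np.uint8)

# Vectorized equivalent of the Java loop: scale to [0, 1], normalize, move channels first
data = ((raw.astype(np.float32) / 255.0 - mean) / std).transpose(2, 0, 1)

# Element-by-element check against the Java code's index arithmetic
flat = data.reshape(-1)
for c in range(channels):
    for h in range(height):
        for w in range(width):
            index = c * height * width + h * width + w
            expected = (raw[h, w, c] / 255.0 - mean[c]) / std[c]
            assert abs(flat[index] - expected) < 1e-5
```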
3. Manual max, exponentials, and normalization
In processOutput, the output probability distribution is computed by hand.
// Compute max, exponentials, and normalization by hand (a numerically stable softmax)
float[] maxValues = new float[batchSize];
float[] expData = new float[outputData.length];
float[] sumExpValues = new float[batchSize];
for (int i = 0; i < batchSize; i++) {
    float maxVal = Float.NEGATIVE_INFINITY;
    for (int j = 0; j < numClasses; j++) {
        maxVal = Math.max(maxVal, outputData[i * numClasses + j]);
    }
    maxValues[i] = maxVal;
}
for (int i = 0; i < batchSize; i++) {
    float sumExp = 0.0f;
    for (int j = 0; j < numClasses; j++) {
        int index = i * numClasses + j;
        expData[index] = (float) Math.exp(outputData[index] - maxValues[i]);
        sumExp += expData[index];
    }
    sumExpValues[i] = sumExp;
}
float[] probabilitiesData = new float[outputData.length];
for (int i = 0; i < batchSize; i++) {
    for (int j = 0; j < numClasses; j++) {
        int index = i * numClasses + j;
        probabilitiesData[index] = expData[index] / sumExpValues[i];
    }
}
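The three loops together form a numerically stable softmax: subtracting the row maximum before exponentiating prevents overflow without changing the result, since the constant factor cancels in the ratio. A NumPy sketch of the same computation:

```python
import numpy as np

def stable_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.1],
                   [1000.0, 999.0, 998.0]])  # large values would overflow a naive softmax
probs = stable_softmax(logits)

assert np.allclose(probs.sum(axis=1), 1.0)  # each row is a probability distribution
assert np.isfinite(probs).all()             # no overflow even with large logits
```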
Memory usage monitoring
The printMemoryUsage method prints the JVM's memory usage, which helps with debugging and tuning.
private void printMemoryUsage(String phase) {
    int mb = 1024 * 1024;
    Runtime runtime = Runtime.getRuntime();
    System.out.println("##### Memory usage (" + phase + ") #####");
    System.out.println("Max memory: " + runtime.maxMemory() / mb + " MB");
    System.out.println("Allocated memory: " + runtime.totalMemory() / mb + " MB");
    System.out.println("Free within allocation: " + runtime.freeMemory() / mb + " MB");
    System.out.println("Memory in use: " + (runtime.totalMemory() - runtime.freeMemory()) / mb + " MB");
}
Summary
This controller provides a complete image-classification service. Although the NDArray cannot call div, exp, or softmax directly, the same operations were written out by hand. The walkthrough and code samples above should make the whole classification pipeline easy to follow.
Full code
package com.example.demo.controller;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.*;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import ai.djl.ndarray.types.Shape;
import java.io.*;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
@RestController
public class ImageClassificationController {
public static final String ONNX_PATH = "C:/Users/jike4/Downloads/CIFAR100_resnet_large.onnx";
// public static final String ONNX_PATH = "C:/Users/jike4/Downloads/resnet18-v1-7.onnx";
@GetMapping("/classify")
public String classify(@RequestParam String imagePath) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
System.out.println("Loading synset...");
List<String> synset = loadSynset("C:/Users/jike4/synset.txt");
System.out.println("Synset loaded: " + synset.size() + " classes");
Translator<Image, Classifications> translator = new CIFAR100Translator(synset);
System.out.println("Setting up criteria...");
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get(ONNX_PATH))
.optTranslator(translator)
.optProgress(new ProgressBar())
.optEngine("OnnxRuntime")
.build();
System.out.println("Loading model...");
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor()) {
System.out.println("Model loaded, loading image...");
File imgFile = new File(imagePath);
System.out.println(imgFile.getPath());
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
System.out.println("Image loaded, running prediction...");
Classifications classifications = predictor.predict(img);
System.out.println("Prediction complete");
return classifications.toString();
}
}
@GetMapping("/testAll")
public String testAll() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
System.out.println("Loading synset...");
List<String> synset = loadSynset("C:/Users/jike4/synset.txt");
System.out.println("Synset loaded: " + synset.size() + " classes");
Translator<Image, Classifications> translator = new CIFAR100Translator(synset);
System.out.println("Setting up criteria...");
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get(ONNX_PATH))
.optTranslator(translator)
.optProgress(new ProgressBar())
.optEngine("OnnxRuntime")
.build();
System.out.println("Loading model...");
try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
Predictor<Image, Classifications> predictor = model.newPredictor()) {
System.out.println("Model loaded, loading test dataset...");
File testDir = new File("C:/Users/Public/cifar100_test_images");
File[] labelDirs = testDir.listFiles(File::isDirectory);
int correctCount = 0;
int totalCount = 0;
for (File labelDir : labelDirs) {
String trueLabelName = labelDir.getName();
int trueLabel = synset.indexOf(trueLabelName);
File[] imageFiles = labelDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".png"));
for (File imageFile : imageFiles) {
Image img = ImageFactory.getInstance().fromFile(imageFile.toPath());
Classifications classifications = predictor.predict(img);
int predictedLabel = synset.indexOf(classifications.best().getClassName());
if (predictedLabel == trueLabel) {
correctCount++;
}
totalCount++;
}
}
double accuracy = (double) correctCount / totalCount;
System.out.println("Test accuracy: " + accuracy);
return "Test accuracy: " + accuracy;
}
}
private List<String> loadSynset(String synsetPath) throws IOException {
List<String> synset = new ArrayList<>();
try (BufferedReader br = new BufferedReader(new FileReader(synsetPath))) {
String line;
while ((line = br.readLine()) != null) {
synset.add(line);
}
}
return synset;
}
}
class CIFAR100Translator implements NoBatchifyTranslator<Image, Classifications> {
private List<String> synset;
public CIFAR100Translator(List<String> synset) {
this.synset = synset;
}
@Override
public NDList processInput(TranslatorContext ctx, Image input) throws IOException {
// Convert Image to NDArray directly
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
// // Print NDArray info for debugging
// System.out.println("NDArray Shape: " + array.getShape());
// System.out.println("NDArray DataType: " + array.getDataType());
// Normalize manually and rearrange the data (HWC -> CHW)
int height = 32;
int width = 32;
int channels = 3;
// Read the image data as a byte array
byte[] rawData = array.toByteArray();
// Float array for the normalized data
float[] data = new float[height * width * channels];
// Normalize with the CIFAR-100 mean and standard deviation
float[] mean = {0.5071f, 0.4865f, 0.4409f};
float[] std = {0.2673f, 0.2564f, 0.2761f};
// Normalization loop
for (int c = 0; c < channels; c++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
int index = c * height * width + h * width + w;
int rawIndex = (h * width + w) * channels + c;
data[index] = ((rawData[rawIndex] & 0xFF) / 255.0f - mean[c]) / std[c];
}
}
}
NDArray tensor = ctx.getNDManager().create(data, new Shape(1, 3, 32, 32));
// Print reshaped NDArray info
// System.out.println("Reshaped NDArray Shape: " + tensor.getShape());
// System.out.println("Reshaped NDArray DataType: " + tensor.getDataType());
return new NDList(tensor);
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
// Get the output NDArray
NDArray outputArray = list.singletonOrThrow();
// Convert it to a float array
float[] outputData = outputArray.toFloatArray();
int batchSize = (int) outputArray.getShape().get(0);
int numClasses = (int) outputArray.getShape().get(1);
// Compute max, exponentials, and normalization by hand
float[] maxValues = new float[batchSize];
float[] expData = new float[outputData.length];
float[] sumExpValues = new float[batchSize];
for (int i = 0; i < batchSize; i++) {
float maxVal = Float.NEGATIVE_INFINITY;
for (int j = 0; j < numClasses; j++) {
maxVal = Math.max(maxVal, outputData[i * numClasses + j]);
}
maxValues[i] = maxVal;
}
for (int i = 0; i < batchSize; i++) {
float sumExp = 0.0f;
for (int j = 0; j < numClasses; j++) {
int index = i * numClasses + j;
expData[index] = (float) Math.exp(outputData[index] - maxValues[i]);
sumExp += expData[index];
}
sumExpValues[i] = sumExp;
}
float[] probabilitiesData = new float[outputData.length];
for (int i = 0; i < batchSize; i++) {
for (int j = 0; j < numClasses; j++) {
int index = i * numClasses + j;
probabilitiesData[index] = expData[index] / sumExpValues[i];
}
}
// Convert the output into a classification result (this indexing assumes a batch size of 1)
List<String> classNames = new ArrayList<>();
List<Double> probs = new ArrayList<>();
for (int i = 0; i < numClasses; i++) {
classNames.add(synset.get(i));
probs.add((double) probabilitiesData[i]);
}
return new Classifications(classNames, probs);
}
private void printMemoryUsage(String phase) {
int mb = 1024 * 1024;
// Query the JVM's memory usage
Runtime runtime = Runtime.getRuntime();
// Print it
System.out.println("##### Memory usage (" + phase + ") #####");
System.out.println("Max memory: " + runtime.maxMemory() / mb + " MB");
System.out.println("Allocated memory: " + runtime.totalMemory() / mb + " MB");
System.out.println("Free within allocation: " + runtime.freeMemory() / mb + " MB");
System.out.println("Memory in use: " + (runtime.totalMemory() - runtime.freeMemory()) / mb + " MB");
}
}
pom dependencies
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.3.0.RELEASE</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.qiyuan</groupId>
<artifactId>demo</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>demo</name>
<description>demo</description>
<url/>
<licenses>
<license/>
</licenses>
<developers>
<developer/>
</developers>
<scm>
<connection/>
<developerConnection/>
<tag/>
<url/>
</scm>
<properties>
<java.version>11</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- DJL API -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.28.0</version>
</dependency>
<!-- DJL ONNX Runtime Engine -->
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.36</version>
</dependency>
<!-- Lombok for reducing boilerplate code -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<repositories>
<repository>
<id>central</id>
<url>https://repo.maven.apache.org/maven2</url>
</repository>
<repository>
<id>aliyun</id>
<url>https://maven.aliyun.com/repository/public</url>
</repository>
<repository>
<id>jitpack.io</id>
<url>https://jitpack.io</url>
</repository>
<repository>
<id>oss.sonatype.org</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</repository>
</repositories>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
Full code for downloading and preparing the dataset
import os
import pickle
import numpy as np
from PIL import Image
import urllib.request
import tarfile

# Download and extract the CIFAR-100 dataset
def download_and_extract_cifar100(data_dir):
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    file_path = os.path.join(data_dir, "cifar-100-python.tar.gz")
    if not os.path.exists(file_path):
        urllib.request.urlretrieve(url, file_path)
        print("Downloaded CIFAR-100 dataset.")
    tar = tarfile.open(file_path)
    tar.extractall(path=data_dir)
    tar.close()
    print("Extracted CIFAR-100 dataset.")
# Load a CIFAR-100 pickle file
def load_cifar100_batch(file):
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Save images into one directory per label
def save_images_from_cifar100(data, labels, label_names, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    for i in range(len(data)):
        img = data[i]
        label = labels[i]
        label_name = label_names[label]
        img = img.reshape(3, 32, 32).transpose(1, 2, 0)
        img = Image.fromarray(img)
        label_dir = os.path.join(output_dir, label_name)
        if not os.path.exists(label_dir):
            os.makedirs(label_dir)
        img.save(os.path.join(label_dir, f'{i}.png'))
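The reshape(3, 32, 32).transpose(1, 2, 0) call above is what turns CIFAR's flat channel-plane layout (1024 red bytes, then 1024 green, then 1024 blue) into the height-width-channel layout PIL expects. A quick NumPy check of that index mapping, using a synthetic array rather than a real image:

```python
import numpy as np

# A CIFAR-100 row stores an image as 1024 red bytes, then 1024 green, then 1024 blue
flat = np.arange(3 * 32 * 32)

# Same transform as in save_images_from_cifar100
hwc = flat.reshape(3, 32, 32).transpose(1, 2, 0)

# Pixel (h, w) takes channel c from position c*1024 + h*32 + w of the flat row
for c in range(3):
    for h in range(32):
        for w in range(32):
            assert hwc[h, w, c] == flat[c * 1024 + h * 32 + w]
```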
# Main entry point
def main():
    data_dir = 'C:/Users/Public'
    cifar100_dir = os.path.join(data_dir, 'cifar-100-python')
    # Download and extract the dataset
    download_and_extract_cifar100(data_dir)
    train_file = os.path.join(cifar100_dir, 'train')
    test_file = os.path.join(cifar100_dir, 'test')
    meta_file = os.path.join(cifar100_dir, 'meta')
    print("Train file path:", train_file)
    print("Test file path:", test_file)
    print("Meta file path:", meta_file)
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Train file not found: {train_file}")
    if not os.path.exists(test_file):
        raise FileNotFoundError(f"Test file not found: {test_file}")
    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Meta file not found: {meta_file}")
    train_data = load_cifar100_batch(train_file)
    test_data = load_cifar100_batch(test_file)
    meta_data = load_cifar100_batch(meta_file)
    label_names = [label.decode('utf-8') for label in meta_data[b'fine_label_names']]
    # Save training images
    save_images_from_cifar100(train_data[b'data'], train_data[b'fine_labels'], label_names, os.path.join(data_dir, 'cifar100_train_images'))
    # Save test images
    save_images_from_cifar100(test_data[b'data'], test_data[b'fine_labels'], label_names, os.path.join(data_dir, 'cifar100_test_images'))

if __name__ == "__main__":
    main()