Learning with GPT: Integrating Java, ResNet, and the CIFAR-100 Dataset for Inference

I find that teacher GPT writes better than most bloggers (myself included): its explanations are easier to follow, and whenever something is unclear I can simply ask. It is tireless, conscientious, on call around the clock, and endlessly patient. With a teacher this good, not studying would be a waste.

So I am setting myself a challenge: study with teacher GPT for 365 days. Every day I will write up my notes and the thread of what I learned (most of the text is generated directly by GPT, which is certainly better than what I would write myself). Thank you, teacher GPT!

Full series: Learning with GPT - AI Series

The ResNet model trained on the CIFAR dataset uses the ONNX model provided by this repository:
https://github.com/Lucas110550/CIFAR100_TinyImageNet_ResNet

I recently picked up a task: integrate an ONNX model into Java and run inference with it. The integration was painful: the DJL (Deep Java Library) source had an infinite-loop bug, and various operators were unsupported. After a series of workarounds the whole pipeline finally ran end to end, so I am publishing it here for anyone who needs it.

ImageClassificationController: A Detailed Walkthrough

Overview

ImageClassificationController is a Spring Boot REST controller for image classification. It loads a pretrained ONNX model through DJL (Deep Java Library), preprocesses images, and runs predictions. The controller exposes two endpoints: one classifies a single image, the other measures the model's accuracy over an entire test dataset.

Main Features

  1. Classify a single image: an endpoint that takes an image path and returns the classification.
  2. Test all images: an endpoint that evaluates the model's accuracy over the whole test dataset.

Code Structure

1. The classify method

Classifies a single image via a GET request to /classify.

  • Steps
    1. Load the synset file, which contains all of the model's class labels.
    2. Set up the CIFAR100Translator, which handles the model's input and output.
    3. Build the model-loading Criteria: model path, translator, and progress bar.
    4. Load the model and create a predictor.
    5. Read the image file and convert it into an Image object.
    6. Run the prediction and return the classification result.
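For example, with the application on Spring Boot's default port 8080 and an image from the test set generated by the Python script at the end of this post (paths are illustrative): GET http://localhost:8080/classify?imagePath=C:/Users/Public/cifar100_test_images/apple/0.png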
2. The testAll method

Evaluates the whole test dataset via a GET request to /testAll.

  • Steps
    1. Load the synset file.
    2. Set up the CIFAR100Translator.
    3. Build the model-loading Criteria.
    4. Load the model and create a predictor.
    5. Walk the test dataset directory, iterating over each class folder and the images inside it.
    6. Predict each image and tally the number of correct predictions.
    7. Compute and return the model's accuracy on the test set.
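Again assuming port 8080, this one takes no parameters: GET http://localhost:8080/testAll. The response is a plain-text accuracy string of the form "Test accuracy: ...".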

CIFAR100Translator

Handles image preprocessing and classification post-processing.

  • The processInput method

    1. Convert the Image object into an NDArray.
    2. Manually normalize and rearrange the data, because methods such as div, exp, and softmax are not directly usable on the NDArray here.
    3. Normalize with the CIFAR-100 dataset's per-channel mean and standard deviation.
    4. Create the normalized NDArray and return it.
  • The processOutput method

    1. Get the model's output NDArray.
    2. Manually compute each class's probability (max, exponentials, normalization), again because exp and softmax are not directly available on the NDArray.
    3. Convert the output into a classification result with class names and probabilities.
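One detail worth noting before the step-by-step code: CIFAR100Translator implements NoBatchifyTranslator, so DJL applies no batchifier, and processInput must produce the batch dimension itself, which is why the input tensor is created with Shape(1, 3, 32, 32) in the full code below.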

Detailed Steps

1. Loading the synset file

The loadSynset method reads the file containing the class labels and stores them in a list.

private List<String> loadSynset(String synsetPath) throws IOException {
    List<String> synset = new ArrayList<>();
    try (BufferedReader br = new BufferedReader(new FileReader(synsetPath))) {
        String line;
        while ((line = br.readLine()) != null) {
            synset.add(line);
        }
    }
    return synset;
}
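synset.txt is expected to contain one label per line, 100 lines in all, in the same order as the model's output indices; for CIFAR-100's fine labels that order is alphabetical, starting apple, aquarium_fish, baby, bear, beaver, and so on. The Python sketch at the end of this post shows one way to generate the file.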
2. Manual normalization and data rearrangement

In the processInput method, normalization is done by hand, because the relevant methods are not directly available on the NDArray.

// Manually normalize and rearrange the data (HWC uint8 -> CHW float)
int height = 32;
int width = 32;
int channels = 3;
byte[] rawData = array.toByteArray();
float[] data = new float[height * width * channels];
float[] mean = {0.5071f, 0.4865f, 0.4409f};
float[] std = {0.2673f, 0.2564f, 0.2761f};

// Normalization: scale to [0, 1], subtract the per-channel mean, divide by the std
for (int c = 0; c < channels; c++) {
    for (int h = 0; h < height; h++) {
        for (int w = 0; w < width; w++) {
            int index = c * height * width + h * width + w;
            int rawIndex = (h * width + w) * channels + c;
            data[index] = ((rawData[rawIndex] & 0xFF) / 255.0f - mean[c]) / std[c];
        }
    }
}
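As a concrete example, a red-channel byte of 128 becomes (128 / 255 − 0.5071) / 0.2673 ≈ −0.019. The index arithmetic also converts the interleaved HWC layout that the image decoder produces into the planar CHW layout the model expects.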
3. Manual max, exponentials, and normalization

In the processOutput method, the output probability distribution is computed by hand.

// Manually compute the max, exponentials, and normalization (a numerically stable softmax)
float[] maxValues = new float[batchSize];
float[] expData = new float[outputData.length];
float[] sumExpValues = new float[batchSize];

for (int i = 0; i < batchSize; i++) {
    float maxVal = Float.NEGATIVE_INFINITY;
    for (int j = 0; j < numClasses; j++) {
        maxVal = Math.max(maxVal, outputData[i * numClasses + j]);
    }
    maxValues[i] = maxVal;
}

for (int i = 0; i < batchSize; i++) {
    float sumExp = 0.0f;
    for (int j = 0; j < numClasses; j++) {
        int index = i * numClasses + j;
        expData[index] = (float) Math.exp(outputData[index] - maxValues[i]);
        sumExp += expData[index];
    }
    sumExpValues[i] = sumExp;
}

float[] probabilitiesData = new float[outputData.length];
for (int i = 0; i < batchSize; i++) {
    for (int j = 0; j < numClasses; j++) {
        int index = i * numClasses + j;
        probabilitiesData[index] = expData[index] / sumExpValues[i];
    }
}
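This is the standard numerically stable softmax: subtracting the row maximum before exponentiating avoids overflow without changing the result. For logits [2.0, 1.0, 0.1], subtracting the max gives [0.0, −1.0, −1.9], the exponentials are roughly [1.0, 0.368, 0.150], their sum is about 1.517, and the resulting probabilities are about [0.659, 0.242, 0.099].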

Monitoring Memory Usage

The printMemoryUsage method prints the JVM's memory usage, which is useful for debugging and tuning.

private void printMemoryUsage(String phase) {
    int mb = 1024 * 1024;
    Runtime runtime = Runtime.getRuntime();
    System.out.println("##### Memory usage (" + phase + ") #####");
    System.out.println("Max memory: " + runtime.maxMemory() / mb + " MB");
    System.out.println("Allocated memory: " + runtime.totalMemory() / mb + " MB");
    System.out.println("Free within allocated: " + runtime.freeMemory() / mb + " MB");
    System.out.println("Actually in use: " + (runtime.totalMemory() - runtime.freeMemory()) / mb + " MB");
}
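In the full listing below this method is defined but never called; a natural use would be to invoke printMemoryUsage("before prediction") and printMemoryUsage("after prediction") around predictor.predict(...) while debugging.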

Summary

This controller provides a complete image-classification service. Although methods such as div, exp, and softmax are not directly usable on the NDArray here, the same operations are implemented by hand. The step-by-step walkthrough and code samples above should make the whole classification process easy to follow.

Full Code

package com.example.demo.controller;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.*;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import ai.djl.ndarray.types.Shape;

import java.io.*;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;



@RestController
public class ImageClassificationController {

    public static final String ONNX_PATH = "C:/Users/jike4/Downloads/CIFAR100_resnet_large.onnx";
//    public static final String ONNX_PATH = "C:/Users/jike4/Downloads/resnet18-v1-7.onnx";

    @GetMapping("/classify")
    public String classify(@RequestParam String imagePath) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {

        System.out.println("Loading synset...");
        List<String> synset = loadSynset("C:/Users/jike4/synset.txt");

        System.out.println("Synset loaded: " + synset.size() + " classes");

        Translator<Image, Classifications> translator = new CIFAR100Translator(synset);

        System.out.println("Setting up criteria...");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(ONNX_PATH))
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .optEngine("OnnxRuntime")
                .build();

        System.out.println("Loading model...");
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Classifications> predictor = model.newPredictor()) {

            System.out.println("Model loaded, loading image...");
            File imgFile = new File(imagePath);
            System.out.println(imgFile.getPath());
            Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); // fromFile avoids leaking an unclosed FileInputStream
            System.out.println("Image loaded, running prediction...");
            Classifications classifications = predictor.predict(img);
            System.out.println("Prediction complete");
            return classifications.toString();
        }
    }

    @GetMapping("/testAll")
    public String testAll() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        System.out.println("Loading synset...");
        List<String> synset = loadSynset("C:/Users/jike4/synset.txt");
        System.out.println("Synset loaded: " + synset.size() + " classes");

        Translator<Image, Classifications> translator = new CIFAR100Translator(synset);

        System.out.println("Setting up criteria...");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(ONNX_PATH))
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .optEngine("OnnxRuntime")
                .build();

        System.out.println("Loading model...");
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Classifications> predictor = model.newPredictor()) {

            System.out.println("Model loaded, loading test dataset...");
            File testDir = new File("C:/Users/Public/cifar100_test_images");
            File[] labelDirs = testDir.listFiles(File::isDirectory);
            if (labelDirs == null) {
                return "Test image directory not found: " + testDir.getPath();
            }

            int correctCount = 0;
            int totalCount = 0;
            for (File labelDir : labelDirs) {
                String trueLabelName = labelDir.getName();
                int trueLabel = synset.indexOf(trueLabelName);
                File[] imageFiles = labelDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".png"));

                for (File imageFile : imageFiles) {
                    Image img = ImageFactory.getInstance().fromFile(imageFile.toPath());
                    Classifications classifications = predictor.predict(img);
                    int predictedLabel = synset.indexOf(classifications.best().getClassName());
                    if (predictedLabel == trueLabel) {
                        correctCount++;
                    }
                    totalCount++;
                }
            }

            double accuracy = (double) correctCount / totalCount;
            System.out.println("Test accuracy: " + accuracy);
            return "Test accuracy: " + accuracy;
        }
    }

    private List<String> loadSynset(String synsetPath) throws IOException {
        List<String> synset = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(synsetPath))) {
            String line;
            while ((line = br.readLine()) != null) {
                synset.add(line);
            }
        }
        return synset;
    }
}

class CIFAR100Translator implements NoBatchifyTranslator<Image, Classifications> {

    private List<String> synset;

    public CIFAR100Translator(List<String> synset) {
        this.synset = synset;
    }

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) throws IOException {
        // Convert Image to NDArray directly
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);

//        // Print NDArray info for debugging
//        System.out.println("NDArray Shape: " + array.getShape());
//        System.out.println("NDArray DataType: " + array.getDataType());

        // Manually normalize and rearrange the data (HWC uint8 -> CHW float)
        int height = 32;
        int width = 32;
        int channels = 3;

        // Get the image data as a byte array
        byte[] rawData = array.toByteArray();

        // Allocate a float array for the normalized data
        float[] data = new float[height * width * channels];

        // Normalize with the CIFAR-100 per-channel mean and standard deviation
        float[] mean = {0.5071f, 0.4865f, 0.4409f};
        float[] std = {0.2673f, 0.2564f, 0.2761f};

        // Normalization: scale to [0, 1], subtract the mean, divide by the std
        for (int c = 0; c < channels; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    int index = c * height * width + h * width + w;
                    int rawIndex = (h * width + w) * channels + c;
                    data[index] = ((rawData[rawIndex] & 0xFF) / 255.0f - mean[c]) / std[c];
                }
            }
        }


        NDArray tensor = ctx.getNDManager().create(data, new Shape(1, 3, 32, 32));

        // Print reshaped NDArray info
//        System.out.println("Reshaped NDArray Shape: " + tensor.getShape());
//        System.out.println("Reshaped NDArray DataType: " + tensor.getDataType());

        return new NDList(tensor);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // Get the output NDArray
        NDArray outputArray = list.singletonOrThrow();

        // Convert it to a float array
        float[] outputData = outputArray.toFloatArray();
        int batchSize = (int) outputArray.getShape().get(0);
        int numClasses = (int) outputArray.getShape().get(1);

        // Manually compute the max, exponentials, and normalization (numerically stable softmax)
        float[] maxValues = new float[batchSize];
        float[] expData = new float[outputData.length];
        float[] sumExpValues = new float[batchSize];

        for (int i = 0; i < batchSize; i++) {
            float maxVal = Float.NEGATIVE_INFINITY;
            for (int j = 0; j < numClasses; j++) {
                maxVal = Math.max(maxVal, outputData[i * numClasses + j]);
            }
            maxValues[i] = maxVal;
        }

        for (int i = 0; i < batchSize; i++) {
            float sumExp = 0.0f;
            for (int j = 0; j < numClasses; j++) {
                int index = i * numClasses + j;
                expData[index] = (float) Math.exp(outputData[index] - maxValues[i]);
                sumExp += expData[index];
            }
            sumExpValues[i] = sumExp;
        }

        float[] probabilitiesData = new float[outputData.length];
        for (int i = 0; i < batchSize; i++) {
            for (int j = 0; j < numClasses; j++) {
                int index = i * numClasses + j;
                probabilitiesData[index] = expData[index] / sumExpValues[i];
            }
        }

        // Convert the probabilities into a Classifications result (batch size is 1, so row 0 is read)
        List<String> classNames = new ArrayList<>();
        List<Double> probs = new ArrayList<>();
        for (int i = 0; i < numClasses; i++) {
            classNames.add(synset.get(i));
            probs.add((double) probabilitiesData[i]);
        }

        return new Classifications(classNames, probs);
    }



    private void printMemoryUsage(String phase) {
        int mb = 1024 * 1024;

        // Query the JVM's memory usage
        Runtime runtime = Runtime.getRuntime();

        // Print it out
        System.out.println("##### Memory usage (" + phase + ") #####");
        System.out.println("Max memory: " + runtime.maxMemory() / mb + " MB");
        System.out.println("Allocated memory: " + runtime.totalMemory() / mb + " MB");
        System.out.println("Free within allocated: " + runtime.freeMemory() / mb + " MB");
        System.out.println("Actually in use: " + (runtime.totalMemory() - runtime.freeMemory()) / mb + " MB");
    }
}
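With the pom below in place, the application can be started the usual Spring Boot way, e.g. mvn spring-boot:run, and the two endpoints called from a browser or curl.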

pom Dependencies

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.3.0.RELEASE</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.qiyuan</groupId>
	<artifactId>demo</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo</name>
	<description>demo</description>
	<url/>
	<licenses>
		<license/>
	</licenses>
	<developers>
		<developer/>
	</developers>
	<scm>
		<connection/>
		<developerConnection/>
		<tag/>
		<url/>
	</scm>
	<properties>
		<java.version>11</java.version>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>
		<!-- DJL API -->
		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>api</artifactId>
			<version>0.28.0</version>
		</dependency>

		<!-- DJL ONNX Runtime Engine -->
		<dependency>
			<groupId>ai.djl.onnxruntime</groupId>
			<artifactId>onnxruntime-engine</artifactId>
			<version>0.28.0</version>
		</dependency>

		<dependency>
			<groupId>org.slf4j</groupId>
			<artifactId>slf4j-simple</artifactId>
			<version>1.7.36</version>
		</dependency>

		<!-- Lombok for reducing boilerplate code -->
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<optional>true</optional>
			<scope>provided</scope>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>
	<repositories>
		<repository>
			<id>central</id>
			<url>https://repo.maven.apache.org/maven2</url>
		</repository>
		<repository>
			<id>aliyun</id>
			<url>https://maven.aliyun.com/repository/public</url>
		</repository>
		<repository>
			<id>jitpack.io</id>
			<url>https://jitpack.io</url>
		</repository>
		<repository>
			<id>oss.sonatype.org</id>
			<url>https://oss.sonatype.org/content/repositories/snapshots</url>
		</repository>
	</repositories>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
				<configuration>
					<excludes>
						<exclude>
							<groupId>org.projectlombok</groupId>
							<artifactId>lombok</artifactId>
						</exclude>
					</excludes>
				</configuration>
			</plugin>
		</plugins>
	</build>

</project>
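A note on the dependencies: ai.djl.onnxruntime:onnxruntime-engine should pull in the CPU build of the ONNX Runtime native library transitively, so no separate native installation is needed for CPU inference; GPU inference would require swapping in the GPU variant of the ONNX Runtime artifact, which I have not tried here.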

Full Code for Downloading and Processing the Dataset

import os
import pickle
import numpy as np
from PIL import Image
import urllib.request
import tarfile

# Download and extract the CIFAR-100 dataset
def download_and_extract_cifar100(data_dir):
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    file_path = os.path.join(data_dir, "cifar-100-python.tar.gz")
    if not os.path.exists(file_path):
        urllib.request.urlretrieve(url, file_path)
        print("Downloaded CIFAR-100 dataset.")

    # Use a context manager so the archive is closed even on error
    with tarfile.open(file_path) as tar:
        tar.extractall(path=data_dir)
    print("Extracted CIFAR-100 dataset.")

# Load a CIFAR-100 pickle batch
def load_cifar100_batch(file):
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # avoid shadowing the built-in name `dict`
    return batch

# Save images into one directory per label
def save_images_from_cifar100(data, labels, label_names, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for i in range(len(data)):
        img = data[i]
        label = labels[i]
        label_name = label_names[label]

        img = img.reshape(3, 32, 32).transpose(1, 2, 0)
        img = Image.fromarray(img)

        label_dir = os.path.join(output_dir, label_name)
        if not os.path.exists(label_dir):
            os.makedirs(label_dir)
        
        img.save(os.path.join(label_dir, f'{i}.png'))

# Main entry point
def main():
    data_dir = 'C:/Users/Public'
    cifar100_dir = os.path.join(data_dir, 'cifar-100-python')

    # Download and extract the dataset
    download_and_extract_cifar100(data_dir)

    train_file = os.path.join(cifar100_dir, 'train')
    test_file = os.path.join(cifar100_dir, 'test')
    meta_file = os.path.join(cifar100_dir, 'meta')

    print("Train file path:", train_file)
    print("Test file path:", test_file)
    print("Meta file path:", meta_file)

    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Train file not found: {train_file}")
    if not os.path.exists(test_file):
        raise FileNotFoundError(f"Test file not found: {test_file}")
    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Meta file not found: {meta_file}")

    train_data = load_cifar100_batch(train_file)
    test_data = load_cifar100_batch(test_file)
    meta_data = load_cifar100_batch(meta_file)

    label_names = [label.decode('utf-8') for label in meta_data[b'fine_label_names']]

    # Save the training set images
    save_images_from_cifar100(train_data[b'data'], train_data[b'fine_labels'], label_names, os.path.join(data_dir, 'cifar100_train_images'))

    # Save the test set images
    save_images_from_cifar100(test_data[b'data'], test_data[b'fine_labels'], label_names, os.path.join(data_dir, 'cifar100_test_images'))

if __name__ == "__main__":
    main()
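The Java controller above reads a synset.txt that this script does not create. Below is a minimal sketch of one way to generate it, under the assumption that the model's class indices follow the order of fine_label_names in the CIFAR-100 meta file; the write_synset helper and where it is called are illustrative additions of mine, not part of the original pipeline.

# Minimal sketch: write label_names, one per line, as the synset file the Java side loads.
# Assumption: class index i in the model's output corresponds to label_names[i].
def write_synset(label_names, path):
    with open(path, 'w', encoding='utf-8') as f:
        for name in label_names:
            f.write(name + '\n')

# For example, inside main() after label_names has been built:
# write_synset(label_names, 'C:/Users/jike4/synset.txt')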
