Spring Boot + DJL: Integrating Deep Learning for Footwear Image Classification

1. Overview

Deep Java Library (DJL) is an open-source, high-level, engine-agnostic Java framework for deep learning. DJL is designed to be easy to get started with and easy for Java developers to use, offering the same native Java development experience and functions as any other regular Java library.

You don't have to be a machine learning or deep learning expert to get started. You can use your existing Java knowledge as an on-ramp to learning and using machine learning and deep learning, and you can build, train, and deploy models from your favorite IDE. DJL makes it straightforward to integrate those models into Java applications.

Because DJL is deep learning engine agnostic, you don't have to commit to an engine when creating a project, and you can switch engines at any time. To ensure the best performance, DJL also provides automatic CPU/GPU selection based on the hardware configuration.
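
As a quick illustration (a minimal sketch that is not part of the original article, using DJL's public Engine API), you can check at runtime which engine and device DJL resolved:

import ai.djl.Device;
import ai.djl.engine.Engine;

public class EngineInfo {
    public static void main(String[] args) {
        // the default engine is resolved from the engines found on the classpath (PyTorch in this project)
        Engine engine = Engine.getInstance();
        // DJL picks a GPU when one is usable, otherwise it falls back to CPU
        Device device = engine.defaultDevice();
        System.out.println("Engine: " + engine.getEngineName() + " " + engine.getVersion());
        System.out.println("Default device: " + device);
    }
}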

Environment and key software versions

  • Java 8+
  • Spring Boot 2.7.16-SNAPSHOT
  • DJL 0.23.0

2. Building and Running the Project

Download the test image dataset

Download the ut-zap50k-images-square image set from https://vision.cs.utexas.edu/projects/finegrained/utzap50k/
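
Note: the ImageFolder dataset used in the training code below derives the classification labels from the first-level folder names under the dataset root, so after extraction the directory is expected to look roughly like this (the four category folders are an assumption based on the labels used in this post):

/home/djl-test/images-test
├── Boots
├── Sandals
├── Shoes
└── Slippers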

Create the Spring Boot project

Key pom.xml dependencies

<dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>

        <!-- pytorch-engine-->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>

        <!-- knife4j -->
        <dependency>
            <groupId>com.github.xiaoymin</groupId>
            <artifactId>knife4j-micro-spring-boot-starter</artifactId>
            <version>3.0.2</version>
        </dependency>
        <dependency>
            <groupId>com.github.xiaoymin</groupId>
            <artifactId>knife4j-spring-boot-starter</artifactId>
            <version>3.0.2</version>
        </dependency>

</dependencies>

<profiles>
    <profile>
        <id>windows</id>
        <activation>
            <activeByDefault>true</activeByDefault>
        </activation>
        <dependencies>
            <!-- Windows CPU -->
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-native-cpu</artifactId>
                <classifier>win-x86_64</classifier>
                <scope>runtime</scope>
                <version>2.0.1</version>
            </dependency>
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-jni</artifactId>
                <version>2.0.1-0.23.0</version>
                <scope>runtime</scope>
            </dependency>
        </dependencies>
    </profile>
    <profile>
        <id>centos7</id>
        <activation>
            <activeByDefault>false</activeByDefault>
        </activation>
        <dependencies>
            <!-- For Pre-CXX11 build (CentOS7)-->
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-native-cpu-precxx11</artifactId>
                <classifier>linux-x86_64</classifier>
                <version>2.0.1</version>
                <scope>runtime</scope>
            </dependency>
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-jni</artifactId>
                <version>2.0.1-0.23.0</version>
                <scope>runtime</scope>
            </dependency>
        </dependencies>
    </profile>
    <profile>
        <id>linux</id>
        <activation>
            <activeByDefault>false</activeByDefault>
        </activation>
        <dependencies>
            <!-- Linux CPU -->
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-native-cpu</artifactId>
                <classifier>linux-x86_64</classifier>
                <scope>runtime</scope>
                <version>2.0.1</version>
            </dependency>
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-jni</artifactId>
                <version>2.0.1-0.23.0</version>
                <scope>runtime</scope>
            </dependency>
        </dependencies>
    </profile>
    <profile>
        <id>aarch64</id>
        <activation>
            <activeByDefault>false</activeByDefault>
        </activation>
        <dependencies>
            <!-- For aarch64 build-->
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-native-cpu-precxx11</artifactId>
                <classifier>linux-aarch64</classifier>
                <scope>runtime</scope>
                <version>2.0.1</version>
            </dependency>
            <dependency>
                <groupId>ai.djl.pytorch</groupId>
                <artifactId>pytorch-jni</artifactId>
                <version>2.0.1-0.23.0</version>
                <scope>runtime</scope>
            </dependency>
        </dependencies>
    </profile>
</profiles>

<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>bom</artifactId>
            <version>0.23.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
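
The ai.djl:bom import above manages the versions of the DJL artifacts, which is why the api, basicdataset, model-zoo, and pytorch-engine dependencies do not declare explicit versions. The native-library profiles are meant to be mutually exclusive: the windows profile is active by default, and a different one can be selected at build time, for example with mvn clean package -P centos7.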

application.yml

server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
knife4j:
  enable: true
djl:
  num-of-output: 4
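
Two notes on this configuration: spring.mvc.pathmatch.matching-strategy is set back to ant_path_matcher because Spring Boot 2.6+ changed the default path matcher and the springfox-based knife4j 3.x UI relies on the old one; djl.num-of-output is read by the service below and must match the number of label folders in the dataset (4 here).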

Models.java

package com.example.djl.demo;

import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;

import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

/**
 * A helper class that builds the model and saves its label synset.
 * @author lakudouzi
 */
public final class Models {

    // the height and width for pre-processing of the image
    public static final int IMAGE_HEIGHT = 100;
    public static final int IMAGE_WIDTH = 100;

    // the name of the model
    public static final String MODEL_NAME = "shoeclassifier";

    private Models() {
    }

    public static Model getModel(int numOfOutput) {
        // create new instance of an empty model
        Model model = Model.newInstance(MODEL_NAME);

        // Block is a composable unit that forms a neural network; combine them like Lego blocks
        // to form a complex network
        Block resNet50 =
                ResNetV1.builder() // construct the network
                        .setImageShape(new Shape(3, IMAGE_HEIGHT, IMAGE_WIDTH))
                        .setNumLayers(50)
                        .setOutSize(numOfOutput)
                        .build();

        // set the neural network to the model
        model.setBlock(resNet50);
        return model;
    }

    public static void saveSynset(Path modelDir, List<String> synset) throws IOException {
        Path synsetFile = modelDir.resolve("synset.txt");
        try (Writer writer = Files.newBufferedWriter(synsetFile)) {
            writer.write(String.join("\n", synset));
        }
    }
}
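
The synset.txt written by saveSynset contains one label per line. At inference time, DJL's ImageClassificationTranslator loads a synset.txt stored alongside the saved parameters by default (as of DJL 0.23.0), which is why the prediction code below does not configure the labels explicitly.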

ImageClassificationServiceImpl.java

package com.example.djl.demo.service.impl;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.example.djl.demo.Models;
import com.example.djl.demo.service.ImageClassificationService;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * @author lakudouzi
 */
@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // represents number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // the number of passes over the complete dataset
    private static final int EPOCHS = 2;

    //the number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    @Override
    public String predict(MultipartFile image, String modePath) throws IOException, MalformedModelException, TranslateException {
        @Cleanup
        InputStream is = image.getInputStream();
        Path modelDir = Paths.get(modePath);
        BufferedImage bi = ImageIO.read(is);
        Image img = ImageFactory.getInstance().fromImage(bi);
        // empty model instance
        try (Model model = Models.getModel(numOfOutput)) {
            // load the model
            model.load(modelDir, Models.MODEL_NAME);
            // define a translator for pre- and post-processing
            // out of the box this translator resizes the input image to the shape the ResNet model expects
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();
            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                log.info("结果={}",predictResult.toJson());
                return predictResult.toJson();
            }
        }
    }

    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("图片数据集训练开始......图片数据集地址路径:{}",datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset(datasetRoot);
        // split the dataset into a training set (80%) and a validation set (20%)
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // set loss function, which seeks to minimize errors
        // loss function evaluates model's predictions against the correct answer (during training)
        // higher numbers are bad - means model performed poorly; indicates more errors; want to
        // minimize errors (loss)
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // setting training parameters (ie hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after done training for inference later
            // model saved as shoeclassifier-0000.params
            model.save(modelDir, Models.MODEL_NAME);
            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("图片数据集训练结束......");
            return String.join("\n", dataset.getSynset());
        }
    }

    private ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }

}
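
The ImageClassificationService interface is referenced but not listed in the original post; a minimal sketch that matches the implementation above would be:

ImageClassificationService.java

package com.example.djl.demo.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

/**
 * Service contract implemented by ImageClassificationServiceImpl.
 */
public interface ImageClassificationService {

    // classify an uploaded image using the model saved under modePath
    String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException;

    // train on the ImageFolder dataset under datasetRoot and save the model to modePath
    String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}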

ImageClassificationController.java

package com.example.djl.demo.rest;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.example.djl.demo.service.ImageClassificationService;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

/**
 * @author lakudouzi
 */
@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    @PostMapping(path = "/analyze")
    @ApiOperation("图片对象识别-上传图片识别")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException,
            MalformedModelException,
            IOException {
        return imageClassificationService.predict(image, modePath);
    }

    @PostMapping(path = "/training")
    @ApiOperation("图片对象识别-训练图片数据集")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test")
                           String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath) throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    @GetMapping("/download")
    @ApiOperation(value = "图片对象识别-下载测试图片", produces = "application/octet-stream")
    public ResponseEntity<Resource> downloadFile(@RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<String> imgPathList = new ArrayList<>();
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Filter only regular files (excluding directories)
            paths.filter(Files::isRegularFile)
                    .forEach(p -> imgPathList.add(p.toString()));
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
        // guard against an empty directory, otherwise nextInt(0) would throw
        if (imgPathList.isEmpty()) {
            return ResponseEntity.notFound().build();
        }
        Random random = new Random();
        String filePath = imgPathList.get(random.nextInt(imgPathList.size()));
        Path file = Paths.get(filePath);
        Resource resource = new FileSystemResource(file.toFile());

        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        HttpHeaders headers = new HttpHeaders();
        headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
        headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);

        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }

}

knife4j is integrated to make testing easier. Start the project and open http://localhost:8888/doc.html to test the endpoints.
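
A typical test flow: call /training first, pointing datasetRoot at the extracted ut-zap50k-images-square folder and modePath at a writable directory; once the model is saved, call /analyze with an uploaded image to get the classification result, and use /download to fetch a random image from the dataset for testing.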

Test results

Train the model from the downloaded dataset

Local testing uses the Windows CPU build for training.

Download a test image and run classification on it

The footwear type predicted for the test image

As the results show, the image is classified as Shoes with the highest probability.

It later ran fine on a CentOS 7 virtual machine as well (just switch the project profile to centos7 when packaging, as noted above).

References:

  • https://docs.djl.ai/engines/pytorch/pytorch-engine/index.html
  • https://docs.djl.ai/docs/demos/footwear_classification/index.html#train-the-footwear-classification-model

That's it for this walkthrough. I'm still not very familiar with DJL, and I don't know whether there are better ways to speed up training or whether an optimal approach exists. Comments and discussion are welcome.
