Integrating DJL into Spring Boot to run a Python-trained PyTorch model: MNIST in practice

Deploying a Python-trained PyTorch model (MNIST) in Java Spring Boot with DJL

Training models in Java with DJL (article series): https://blog.csdn.net/xundh/category_11361043.html?spm=1001.2014.3001.5515

DJL official documentation: https://docs.djl.ai/index.html

Training the PyTorch model in Python

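This project was trained with PyTorch==1.10.0. The relevant packages in the conda environment (excerpt from conda list; columns: name, version, build, channel):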

pytorch                   1.10.0          py3.9_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
requests                  2.28.1                   pypi_0    pypi
scipy                     1.9.3                    pypi_0    pypi
setuptools                65.6.3             pyhd8ed1ab_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
six                       1.16.0                   pypi_0    pypi
tbb                       2021.7.0             h91493d7_1    conda-forge
tk                        8.6.12               h8ffe710_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
torchaudio                0.10.0               py39_cu113    pytorch
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.11.0               py39_cu113    pytorch

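Change the model-saving code so that it exports a TorchScript module by tracing, because DJL's PyTorch engine loads models in TorchScript format: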

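import torch

# model, device and val_ac (validation accuracy) are defined earlier in the training script
model.eval()  # switch to evaluation mode before tracing
example = torch.rand(1, 1, 28, 28).to(device)  # dummy input shaped like one MNIST image: (N, C, H, W)
traced_script_module = torch.jit.trace(model, example)  # record the forward pass as TorchScript
traced_script_module.save('models/{}_model.pt'.format(val_ac))  # save the TorchScript model; models/ must already exist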
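The traced module expects input shaped like the example tensor, (1, 1, 28, 28), so whatever translator the Java service uses has to flatten each image into that NCHW layout and decode the 10 output scores into a digit. The class below is a hypothetical, stdlib-only sketch of that logic; the name MnistCodec and its methods are illustrative and not part of DJL. It assumes the training script scaled pixels to [0, 1] with transforms.ToTensor() and applied no Normalize step (if it did, apply the same mean and std here), and that inputs are white digits on a black background as in MNIST, so black-on-white images would need inverting first.

```java
import java.awt.image.BufferedImage;

/**
 * Hypothetical helper, not part of DJL: pre- and post-processing for the traced MNIST model.
 * Assumption: training used ToTensor() only (pixels in [0, 1], no Normalize) and
 * white-on-black digits as in MNIST.
 */
final class MnistCodec {
    static final int SIZE = 28;

    /** Flattens a 28x28 image into NCHW order for shape (1, 1, 28, 28): index = y * 28 + x. */
    static float[] toInput(BufferedImage img) {
        if (img.getWidth() != SIZE || img.getHeight() != SIZE) {
            throw new IllegalArgumentException("expected a 28x28 image");
        }
        float[] data = new float[SIZE * SIZE]; // N = 1 and C = 1, so only H and W remain
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int rgb = img.getRGB(x, y);
                // average R, G and B to get one grayscale value in [0, 255]
                int gray = (((rgb >> 16) & 0xFF) + ((rgb >> 8) & 0xFF) + (rgb & 0xFF)) / 3;
                data[y * SIZE + x] = gray / 255f; // same scaling as torchvision ToTensor()
            }
        }
        return data;
    }

    /** Softmax over the output scores, shifted by the maximum for numerical stability. */
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        float[] probs = new float[logits.length];
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Index of the largest score, i.e. the predicted digit 0-9. */
    static int argmax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_INT_RGB); // all black
        img.setRGB(3, 2, 0xFFFFFF); // one white pixel at x = 3, y = 2
        float[] input = toInput(img);
        System.out.println(input.length + " " + input[2 * SIZE + 3]); // prints "784 1.0"
        float[] logits = {0.1f, -1f, 3f, 0f, 0f, 0f, 0f, 0f, 0f, 0f};
        System.out.println(argmax(softmax(logits))); // prints "2"
    }
}
```

In a DJL Translator, the array from toInput would typically be wrapped into an NDArray of Shape(1, 1, 28, 28) in processInput, and argmax applied to the output in processOutput. Softmax does not change which index is largest, so it is only needed to report a confidence; if the model's last layer is log_softmax, applying softmax to its outputs recovers the class probabilities.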

Calling the model from Java Spring Boot with DJL

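pom.xml (DJL 0.19.0; the pytorch-jni version 1.12.1-0.19.0 pairs the PyTorch 1.12.1 native library with DJL 0.19.0):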

<!-- DJL dependencies -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.19.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.12.1-0.19.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.19.0</version>
</dependency>

<dependency>
<groupId>ai.djl.opencv</groupId>
<artifactId>opencv</artifactId>
