yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

前言

这篇博客主要是介绍如何通过djl在java中调用yolov5进行推理,顺便也学习了一下在java上的opencv api。
Deep Java Library是由亚马逊(Amazon)提供的一个深度学习工具包,能够让java开发者在java上调用目前主流的深度学习框架,像pytorch、tensorflow、mxnet、paddlepaddle(飞桨居然也有份😂),也包括onnx格式的模型。

环境

导出yolov5s模型

这次demo就直接使用yolov5s的预训练模型。yolov5项目本身就自带了非常完善的模型导出脚本,yolov5的5.0发行版也比之前的版本完善很多。
yolov5的模型导出脚本是models/export.py文件,
在这里插入图片描述
导出之前需要设置一下

  • 权重文件的位置
  • 输入图片的尺寸
  • 是否要输出bbox
  • 模型所在设备
    在这里插入图片描述
    上图红色的框按我的进行设置就行了,绿色的框根据自己的情况进行设置。
    设置好以后运行代码就可以在和权重文件相同的位置找到生成的torchscript模型权重。
    在这里插入图片描述

编写Maven项目

编写pom.xml文件

djl使用pytorch需要引入相关依赖

  • pytorch-model-zoo
  • pytorch-engine
  • pytorch-native-auto
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>xyz.hyhy</groupId>
    <artifactId>TestAI</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.11.0</djl.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.8.1</version>
        </dependency>
    </dependencies>
</project>

引入opencv依赖

下载opencv

到官网下载opencv库
在这里插入图片描述

获取opencv的jar包和动态链接库dll文件

下载完会得到一个exe文件,实际只是个压缩包,解压后到build文件夹下将jar包和x64或x86文件夹下的dll文件一起复制到项目的lib文件夹下。dll文件根据自己系统是64位还是32位进行选择。
在这里插入图片描述

将lib文件夹添加为Library

在这里插入图片描述

将yolov5权重文件放到资源文件

将之前导出的yolov5s.torchscript.pt文件放到resources/yolov5s文件夹下。另外还要编写一个coco.names文件,用来说明分类任务的类名。
在这里插入图片描述
coco.names

person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

编写代码

package xyz.hyhy;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import xyz.hyhy.utils.MyUtils;

import java.io.IOException;

import static org.opencv.videoio.Videoio.CAP_ANY;

public class Main {

    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    public static void main(String[] args) {
        Translator<Image, DetectedObjects> translator = YoloV5Translator.builder().optSynsetArtifactName("coco.names").build();
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optDevice(Device.cpu())
                        .optModelUrls(Main.class.getResource("/yolov5s").getPath())
                        .optModelName("yolov5s.torchscript.pt")
                        .optTranslator(translator)
                        .optEngine("PyTorch")
                        .build();
//        Criteria<Image, DetectedObjects> criteria =
//                Criteria.builder()
//                        .setTypes(Image.class, DetectedObjects.class)
//                        .optDevice(Device.cpu())
//                        .optModelUrls(Main.class.getResource("/yolov5").getPath())
//                        .optModelName("yolov5s.onnx")
//                        .optTranslator(translator)
//                        .optEngine("OnnxRuntime")
//                        .build();
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            VideoCapture cap = new VideoCapture(CAP_ANY);
            if (!cap.isOpened()) {//isOpened函数用来判断摄像头调用是否成功
                System.out.println("Camera Error");//如果摄像头调用失败,输出错误信息
            } else {
                Mat frame = new Mat();//创建一个输出帧
                boolean flag = cap.read(frame);//read方法读取摄像头的当前帧
                while (flag) {
                    detect(frame, model);
                    HighGui.imshow("yolov5", frame);
                    HighGui.waitKey(20);
                    flag = cap.read(frame);
                }
            }

        } catch (RuntimeException | ModelException | TranslateException | IOException e) {
            e.printStackTrace();
        }
    }

    static Rect rect = new Rect();
    static Scalar color = new Scalar(0, 255, 0);

    static void detect(Mat frame, ZooModel<Image, DetectedObjects> model) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        Image img = MyUtils.mat2Image(frame);
        long startTime = System.currentTimeMillis();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects results = predictor.predict(img);
//            System.out.println(results);
            for (DetectedObject obj : results.<DetectedObject>items()) {
                BoundingBox bbox = obj.getBoundingBox();
                Rectangle rectangle = bbox.getBounds();
                String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
                rect.x = (int) rectangle.getX();
                rect.y = (int) rectangle.getY();
                rect.width = (int) rectangle.getWidth();
                rect.height = (int) rectangle.getHeight();
                // 画框

                Imgproc.rectangle(frame, rect, color, 2);
                //画名字
                Imgproc.putText(frame, showText,
                        new Point(rect.x, rect.y),
                        Imgproc.FONT_HERSHEY_COMPLEX,
                        rectangle.getWidth() / 200,
                        color);
            }
        }
        System.out.println(String.format("%.2f", 1000.0 / (System.currentTimeMillis() - startTime)));
    }
}


运行程序

程序启动时,会卡住一段时间,不过不要慌,因为djl需要下载pytorch的动态链接库,下载的位置在%USERPROFILE%\.djl.ai\pytorch目录下。可以看一下加速球的流量消耗或者到对应文件夹下确认是否有在下载。
下载的实际上就是libtorch里面的那些动态链接库。djl会根据你的系统自动选择下载合适的版本(应该)。
在这里插入图片描述
效果:
在这里插入图片描述

补充

之后测试了onnx的yolov5s模型,onnx的推理速度更快,速度大概是torchscript的3倍。

MyUtils.mat2Image

    public static Image mat2Image(Mat mat) {
        return ImageFactory.getInstance().fromImage(HighGui.toBufferedImage(mat));
    }
评论 106
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

虹幺

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值