1.主程序文件
package com.xxx.onnx;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.BufferedImageFactory;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.bytedeco.ffmpeg.global.avutil;
import org.bytedeco.javacv.Java2DFrameUtils;
import org.bytedeco.javacv.*;
import org.bytedeco.opencv.opencv_core.Mat;
import org.opencv.core.Core;
import org.opencv.core.MatOfPoint;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;
import javax.swing.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
/**
 * Demo that pulls frames from an RTSP camera with JavaCV/FFmpeg, runs a YOLOv5 ONNX model on a
 * subset of them through DJL (OnnxRuntime engine, CPU), draws the latest detections plus a fixed
 * polygon, and shows the annotated frames in a {@link CanvasFrame} window. Per-stage timings are
 * printed to stdout.
 */
public class Rtsp {
    // NOTE(review): camera credentials are hard-coded in the URL; consider moving them to configuration.
    private static final String RTSP = "rtsp://admin:admin1234@192.168.66.150:554/cam/realmonitor?channel=4&subtype=1";
    /** Absolute path of the exported YOLOv5 ONNX model. */
    private static final String path = "D:\\LIHAOWORK\\models\\yolov5-pt\\model\\person\\person.onnx";
    /**
     * Vertices of the closed polygon drawn on every frame, in pixel coordinates.
     * NOTE(review): y reaches 720 while frames are requested at 640x640 - confirm the intended
     * coordinate space.
     */
    private static final org.opencv.core.Point[] points = {new org.opencv.core.Point(0, 300),
            new org.opencv.core.Point(350, 340),
            new org.opencv.core.Point(400, 500),
            new org.opencv.core.Point(0, 720),};
    /** Run inference on one frame out of this many; the frames in between reuse {@link #result}. */
    private static final int INFER_EVERY_N_FRAMES = 10;
    /** Loaded model; kept so it can be closed on shutdown. */
    private static ZooModel<Image, DetectedObjects> model;
    private static Predictor<Image, DetectedObjects> predictor;
    /** Most recent detection result; null until the first successful inference. */
    private static DetectedObjects result;
    private static float threshold = 0.2f;
    private static int frameRate = 30;
    private static int width = 640;
    private static int height = 640;

    /**
     * Loads the ONNX model and its predictor, then loads the OpenCV native library.
     *
     * @throws IllegalStateException if the model cannot be loaded; continuing without a predictor
     *     would only fail later with a NullPointerException on every frame
     */
    private static void init() {
        Translator<Image, DetectedObjects> translator = YoloV5Translator
                .builder()
                .optThreshold(threshold)
                .optSynsetArtifactName("synset.txt")
                .build();
        // Wrap the stock translator so box coordinates are divided by the 640x640 model input size.
        YoloV5RelativeTranslator myTranslator = new YoloV5RelativeTranslator(translator);
        try {
            model = Criteria.builder()
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    .optDevice(Device.cpu())
                    .optEngine("OnnxRuntime")
                    .setTypes(Image.class, DetectedObjects.class)
                    .optTranslator(myTranslator)
                    .optModelPath(Paths.get(path))
                    .optProgress(new ProgressBar())
                    .build().loadModel();
            predictor = model.newPredictor();
            System.out.println("模型加载完成");
            // Needed by the org.opencv.* calls made in drawRect.
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
            System.out.println("底层库加载完成");
        } catch (IOException | ModelNotFoundException | MalformedModelException e) {
            throw new IllegalStateException("Failed to load model from " + path, e);
        }
    }

    /**
     * Entry point: opens the RTSP stream, loads the model, and displays annotated frames until the
     * stream ends, the window is closed, or an error occurs. The grabber, predictor and model are
     * released on every exit path of this method.
     */
    public static void main(String[] args) {
        System.out.println("开始抽帧");
        try (FFmpegFrameGrabber grabber = FFmpegFrameGrabber.createDefault(RTSP)) {
            grabber.setOption("rtsp_transport", "tcp");
            // Socket timeout in microseconds (5 s).
            grabber.setOption("stimeout", "5000000");
            grabber.setPixelFormat(avutil.AV_PIX_FMT_RGB24);
            grabber.setImageWidth(width);
            grabber.setImageHeight(height);
            grabber.setFrameRate(frameRate);
            grabber.start();
            System.out.println("初始化模型");
            init();
            System.out.println("播放窗口");
            CanvasFrame canvasFrame = new CanvasFrame("摄像机");
            canvasFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            canvasFrame.setAlwaysOnTop(true);
            System.out.println("核心处理逻辑");
            int i = 0;
            while (canvasFrame.isVisible()) {
                // grabImage() skips audio packets, which grabFrame() would return with a null image.
                Frame frame = grabber.grabImage();
                if (frame == null) {
                    // End of stream: stop cleanly instead of failing inside processFrame.
                    break;
                }
                canvasFrame.showImage(processFrame(frame, i));
                // Frame index cycles through [0, frameRate).
                i = (i + 1) % frameRate;
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            closeModel();
        }
    }

    /** Releases the DJL predictor and model if {@link #init()} created them. */
    private static void closeModel() {
        if (predictor != null) {
            predictor.close();
            predictor = null;
        }
        if (model != null) {
            model.close();
            model = null;
        }
    }

    /**
     * Annotates one frame: runs detection on every {@link #INFER_EVERY_N_FRAMES}-th frame, draws
     * the latest detections and the region polygon, and converts the result back to a frame.
     *
     * @param frame a video frame carrying an image
     * @param i     position of the frame within the current cycle, in [0, frameRate)
     * @return a new frame holding the annotated image
     */
    private static Frame processFrame(Frame frame, int i) {
        System.out.println("1(frame2Image)");
        long start1 = System.currentTimeMillis();
        Image image = frame2Image(frame);
        long end1 = System.currentTimeMillis();
        System.out.println("frame2Image耗时:" + (end1 - start1) + "ms");
        if (i % INFER_EVERY_N_FRAMES == 0) {
            try {
                System.out.println("2(推理)");
                long start2 = System.currentTimeMillis();
                result = predictor.predict(image);
                long end2 = System.currentTimeMillis();
                System.out.println("推理耗时:" + (end2 - start2) + "ms");
            } catch (TranslateException e) {
                // Keep the previous result so playback continues; just report the failure.
                e.printStackTrace();
            }
        }
        System.out.println("3(结果)");
        System.out.println(result);
        System.out.println("4(绘制)");
        long start3 = System.currentTimeMillis();
        // Nothing to draw until the first inference has succeeded.
        if (result != null) {
            image.drawBoundingBoxes(result);
        }
        long end3 = System.currentTimeMillis();
        System.out.println("绘制耗时:" + (end3 - start3) + "ms");
        System.out.println("5(image2Frame)");
        long start4 = System.currentTimeMillis();
        Mat mat = image2Mat(image);
        drawRect(mat, points);
        Frame frameout = mat2Frame(mat);
        long end4 = System.currentTimeMillis();
        System.out.println("image2Frame耗时:" + (end4 - start4) + "ms");
        return frameout;
    }

    /** Converts a JavaCV frame to a DJL {@link Image} wrapping a {@link BufferedImage}. */
    private static Image frame2Image(Frame frame) {
        BufferedImage temp = Java2DFrameUtils.toBufferedImage(frame);
        return BufferedImageFactory.getInstance().fromImage(temp);
    }

    /**
     * Converts a DJL image back to a JavaCV frame. Currently unused.
     * Assumes the image wraps a {@link BufferedImage}, as images from {@link #frame2Image} do.
     */
    private static Frame image2Frame(Image image) {
        BufferedImage temp = (BufferedImage) image.getWrappedImage();
        return Java2DFrameUtils.toFrame(temp);
    }

    /**
     * Converts a DJL image to a JavaCV (bytedeco) {@link Mat}.
     * Assumes the image wraps a {@link BufferedImage}, as images from {@link #frame2Image} do.
     */
    private static Mat image2Mat(Image image) {
        BufferedImage temp = (BufferedImage) image.getWrappedImage();
        return Java2DFrameUtils.toMat(temp);
    }

    /** Converts a JavaCV (bytedeco) {@link Mat} to a frame for display. */
    private static Frame mat2Frame(Mat mat) {
        return Java2DFrameUtils.toFrame(mat);
    }

    /**
     * Draws {@code points} as a closed magenta-ish polyline (Scalar(255, 0, 255), thickness 5).
     * The bytedeco Mat is turned into an {@code org.opencv.core.Mat} through JavaCV's frame
     * converters. {@link #processFrame} relies on that view sharing {@code mat}'s pixel buffer, so
     * the drawing shows up in {@code mat}.
     * NOTE(review): confirm the no-copy behaviour for the JavaCV version in use.
     */
    private static void drawRect(Mat mat, org.opencv.core.Point[] points) {
        OpenCVFrameConverter.ToMat converter1 = new OpenCVFrameConverter.ToMat();
        OpenCVFrameConverter.ToOrgOpenCvCoreMat converter2 = new OpenCVFrameConverter.ToOrgOpenCvCoreMat();
        org.opencv.core.Mat cvmat = converter2.convert(converter1.convert(mat));
        MatOfPoint ps = new MatOfPoint();
        ps.fromArray(points);
        Scalar scalar = new Scalar(255, 0, 255);
        Imgproc.polylines(cvmat, Arrays.asList(ps), true, scalar, 5, Imgproc.LINE_8);
    }
}
2.转换器文件
package com.xxx.onnx;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
/**
 * {@link Translator} decorator that delegates pre- and post-processing to a wrapped YOLOv5
 * translator. It then divides every bounding box's x/width by {@code width} and y/height by
 * {@code height}. Presumably the delegate returns pixel coordinates of the model input, so the
 * boxes become relative to the input size.
 */
public class YoloV5RelativeTranslator implements Translator<Image, DetectedObjects> {

    /** Model input side length assumed by the single-argument constructor. */
    private static final int DEFAULT_INPUT_SIZE = 640;

    private final Translator<Image, DetectedObjects> delegated;
    private final double width;
    private final double height;

    /**
     * Wraps {@code translator}, assuming a 640x640 model input.
     *
     * @param translator the translator doing the actual YOLOv5 processing
     */
    public YoloV5RelativeTranslator(Translator<Image, DetectedObjects> translator) {
        this(translator, DEFAULT_INPUT_SIZE, DEFAULT_INPUT_SIZE);
    }

    /**
     * Wraps {@code translator} for a model with the given input size.
     *
     * @param translator the translator doing the actual YOLOv5 processing
     * @param width      model input width in pixels, used to normalize x and box width
     * @param height     model input height in pixels, used to normalize y and box height
     * @throws NullPointerException     if {@code translator} is null
     * @throws IllegalArgumentException if {@code width} or {@code height} is not positive
     */
    public YoloV5RelativeTranslator(Translator<Image, DetectedObjects> translator, int width, int height) {
        if (translator == null) {
            throw new NullPointerException("translator");
        }
        if (width <= 0 || height <= 0) {
            throw new IllegalArgumentException("input size must be positive: " + width + "x" + height);
        }
        this.delegated = translator;
        this.width = width;
        this.height = height;
    }

    /**
     * Lets the delegate decode the model output, then rebuilds the detections with scaled boxes.
     * Class names and probabilities are copied unchanged, in the same order.
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception {
        DetectedObjects output = delegated.processOutput(ctx, list);
        final List<DetectedObjects.DetectedObject> items = output.items();
        List<String> classList = new ArrayList<>(items.size());
        List<Double> probList = new ArrayList<>(items.size());
        List<BoundingBox> rectList = new ArrayList<>(items.size());
        for (DetectedObjects.DetectedObject item : items) {
            classList.add(item.getClassName());
            probList.add(item.getProbability());
            Rectangle b = item.getBoundingBox().getBounds();
            rectList.add(new Rectangle(
                    b.getX() / width, b.getY() / height, b.getWidth() / width, b.getHeight() / height));
        }
        return new DetectedObjects(classList, probList, rectList);
    }

    /** Delegates input preprocessing unchanged. */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
        return delegated.processInput(ctx, input);
    }

    /** Delegates preparation (for example loading the synset) unchanged. */
    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        delegated.prepare(ctx);
    }

    /** Uses the delegate's batchifier. */
    @Override
    public Batchifier getBatchifier() {
        return delegated.getBatchifier();
    }
}
3.POM文件
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.lihao</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
    <name>Spring Boot Blank Project (from https://github.com/making/spring-boot-blank)</name>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.12</version>
    </parent>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- NOTE(review): must name the class whose main() should run from the packaged jar;
             the RTSP demo's entry point is the Rtsp class - confirm which one is intended. -->
        <start-class>com.lihao.App</start-class>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.23.0</version>
        </dependency>
        <!-- OpenCV Java bindings (org.opencv.*) from a local jar. Maven rejects a dependency without
             a version, system scope included; 4.8.0 matches the opencv-480.jar file name.
             Forward slashes keep systemPath portable across operating systems. The matching
             native library must be on java.library.path at run time. -->
        <dependency>
            <groupId>org</groupId>
            <artifactId>opencv</artifactId>
            <version>4.8.0</version>
            <scope>system</scope>
            <systemPath>${project.basedir}/src/main/resources/lib/opencv-480.jar</systemPath>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv</artifactId>
            <version>1.5.6</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ffmpeg-platform</artifactId>
            <version>4.4-1.5.6</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
            <version>1.5.6</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.serving</groupId>
            <artifactId>wlm</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.onnxruntime</groupId>
            <artifactId>onnxruntime-engine</artifactId>
            <version>0.23.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>2.0.1</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>2.0.1-0.23.0</version>
            <scope>runtime</scope>
        </dependency>
    </dependencies>
    <build>
        <finalName>djl</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <!-- No explicit version: inherit the one managed by spring-boot-starter-parent (2.7.12)
                 instead of pinning an older 2.6.0 plugin. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <!-- Package the system-scoped OpenCV jar into the executable jar; it is
                         excluded by default. -->
                    <includeSystemScope>true</includeSystemScope>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
4.demo运行结果