Deploying yolovx object detection with java-onnx

Update 2023-9-16:

Someone asked in the comments about video inference; the source code is below. yolov5s runs at 45 fps single-threaded and 80 fps multi-threaded on an i9-12900K CPU, and 90 fps single-threaded and 150 fps multi-threaded on an RTX 3070.
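The multi-threaded numbers come from decoupling frame capture from inference with a buffered frame queue (the class below is the single-threaded, no-queue variant). A minimal sketch of that pattern, assuming the cap and infer(...) from the listing below; the queue capacity and thread count are illustrative, and annotated-frame ordering is ignored for brevity:

import java.util.concurrent.*;

// decode frames on the main thread, run infer() on a small worker pool
BlockingQueue<Mat> queue = new ArrayBlockingQueue<>(8);
ExecutorService pool = Executors.newFixedThreadPool(4);
for (int i = 0; i < 4; i++) {
    pool.submit(() -> {
        try {
            while (true) {
                Mat f = queue.take(); // blocks until a frame is available
                infer(f);             // same per-frame inference as in the listing below
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    });
}
Mat frame = new Mat();
while (cap.read(frame)) {
    queue.put(frame.clone()); // clone: cap.read() reuses the same buffer
}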

Code:

package tool.deeplearning;


import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;
import sun.font.FontDesignMetrics;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.List;

/**
*   @desc : real-time yolov5 video inference, single-threaded (no frame buffer queue)
*   @auth : tyf
*   @date : 2023-09-16  14:37:29
*/
public class yolov5_predict_video2 {


    // onnxruntime environment
    public static OrtEnvironment env;
    public static OrtSession session;

    // the model's class names, read from the weight metadata
    public static List<String> clazzStr;

    // the model input shape, read from the weights
    public static int count;//1 image per batch
    public static int channels;//3 input channels
    public static int netHeight;//640 network input height
    public static int netWidth;//640 network input width

    // box filtering thresholds, following the settings in detect.py
    public static float confThreshold = 0.65f;
    public static float nmsThreshold = 0.45f;

    // annotation color
    public static Scalar color = new Scalar(0, 0, 255);
    public static int thickness = 2;

    static {
        try {

            String weight = new File("").getCanonicalPath() + "\\model\\deeplearning\\yolov5\\yolov5s.onnx";
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(weight, new OrtSession.SessionOptions());
            OnnxModelMetadata metadata = session.getMetadata();
            Map<String, NodeInfo> infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            JSONObject names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
            clazzStr = new ArrayList<>();
            names.entrySet().forEach(n->{
                clazzStr.add(String.valueOf(n.getValue()));
            });
            count = (int)nodeInfo.getShape()[0];//1 image per batch
            channels = (int)nodeInfo.getShape()[1];//3 input channels
            netHeight = (int)nodeInfo.getShape()[2];//640 network input height
            netWidth = (int)nodeInfo.getShape()[3];//640 network input width
//            System.out.println("channels="+channels+", input height="+netHeight+", input width="+netWidth);
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        }
        catch (Exception e){
            e.printStackTrace();
        }
    }

    // detection box
    public static class  Detection{
        float x1;
        float y1;
        float x2;
        float y2;
        int type_max_index;
        float type_max_value;
        String type_max_name;
        public Detection(float[] box){
            // xywh
            float x = box[0];
            float y = box[1];
            float w = box[2];
            float h = box[3];
            // x1y1x2y2
            this.x1 = x - w * 0.5f;
            this.y1 = y - h * 0.5f;
            this.x2 = x + w * 0.5f;
            this.y2 = y + h * 0.5f;
            // find the highest class probability; class scores start at index 5
            int max_index = 0;
            float max_value = 0;
            for (int i = 5; i < box.length; i++) {
                if (box[i] > max_value) {
                    max_value = box[i];
                    max_index = i;
                }
            }
            type_max_index = max_index - 5;
            type_max_value = max_value;
            type_max_name = clazzStr.get(type_max_index);
        }
        // compute the IoU of two boxes (areas use the +1 inclusive-pixel convention)
        private static double calculateIoU(Detection box1, Detection box2) {
            double x1 = Math.max(box1.x1, box2.x1);
            double y1 = Math.max(box1.y1, box2.y1);
            double x2 = Math.min(box1.x2, box2.x2);
            double y2 = Math.min(box1.y2, box2.y2);
            double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
            double box1Area = (box1.x2 - box1.x1 + 1) * (box1.y2 - box1.y1 + 1);
            double box2Area = (box2.x2 - box2.x1 + 1) * (box2.y2 - box2.y1 + 1);
            double unionArea = box1Area + box2Area - intersectionArea;
            return intersectionArea / unionArea;
        }
    }

    public static Mat resizeWithPadding(Mat src) {
        Mat dst = new Mat();
        int oldW = src.width();
        int oldH = src.height();
        double r = Math.min((double) netWidth / oldW, (double) netHeight / oldH);
        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);
        int dw = (netWidth - newUnpadW) / 2;
        int dh = (netHeight - newUnpadH) / 2;
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);
        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
        return dst;
    }

    // convert interleaved HWC pixels (R0 G0 B0 R1 G1 B1 ...) to planar CHW
    // (all R, then all G, then all B), the layout the network input expects
    public static float[] hwc2chw(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    /**
    *   @desc : run inference on one frame and annotate it in place
    *   @auth : tyf
    *   @date : 2023-09-16  15:48:09
    */
    public static long infer(Mat frame){

        long ts = System.currentTimeMillis();

        // letterbox resize to the network input size
        Mat input = resizeWithPadding(frame);
        // BGR -> RGB
        Imgproc.cvtColor(input, input, Imgproc.COLOR_BGR2RGB);
        // normalize 0-255 to 0-1 (CV_32FC1 only sets the depth; the channel count is kept)
        input.convertTo(input, CvType.CV_32FC1, 1. / 255);

        // extract the pixels
        float[] hwc = new float[channels * netHeight * netWidth];
        input.get(0, 0, hwc);
        float[] chw = hwc2chw(hwc);

        // build the NCHW input tensor and run inference
        try {
            OnnxTensor tensor_input = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count, channels, netHeight, netWidth});
            OrtSession.Result result = session.run(Collections.singletonMap("images", tensor_input));
            OnnxTensor tensor_output = (OnnxTensor)result.get(0);

            // post-process the output, shape 1 x 25200 x (5 + num_classes), e.g. 1 x 25200 x 85 for COCO
            float[][] data = ((float[][][])tensor_output.getValue())[0];

            List<Detection> box_before_nms = new ArrayList<>();
            List<Detection> box_after_nms = new ArrayList<>();
            for(int i=0;i<data.length;i++){
                float[] obj = data[i];
                if(obj[4]>=confThreshold){
                    box_before_nms.add(new Detection(obj));
                }
            }

            // greedy NMS: keep the best-scoring box, drop everything overlapping it too much
            box_before_nms.sort((o1, o2) -> Float.compare(o2.type_max_value,o1.type_max_value));
            while (!box_before_nms.isEmpty()){
                Detection maxObj = box_before_nms.get(0);
                box_after_nms.add(maxObj);
                Iterator<Detection> it = box_before_nms.iterator();
                while (it.hasNext()) {
                    Detection obj = it.next();
                    // compute IoU; this also removes maxObj itself (IoU = 1)
                    if(Detection.calculateIoU(maxObj,obj)>nmsThreshold){
                        it.remove();
                    }
                }
            }

            // annotate the surviving boxes
            box_after_nms.stream().forEach(n->{

                float x1 = n.x1;
                float y1 = n.y1;
                float x2 = n.x2;
                float y2 = n.y2;

                // map back to the original image coordinates
                float[] x1y1x2y2 = xy2xy(frame.width(),frame.height(),new float[]{x1,y1,x2,y2});
                x1 = x1y1x2y2[0];
                y1 = x1y1x2y2[1];
                x2 = x1y1x2y2[2];
                y2 = x1y1x2y2[3];

                // class name and confidence
                String clazz = n.type_max_name;
                String percent = String.format("%.2f", n.type_max_value*100)+"%";

                // bounding box
                Imgproc.rectangle(frame, new Point(x1,y1), new Point(x2,y2), color, thickness);
                // class label
                putText(frame,clazz+" "+percent,(int)x1,(int)y1-13-thickness,13,Color.BLACK,Color.RED);

            });
            tensor_input.close();
            tensor_output.close();
            input.release();
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }

        long te = System.currentTimeMillis();
        return te-ts;
    }



    // map x1y1x2y2 from the model input space (netWidth x netHeight) back to the
    // original image space (w1 x h1), undoing the letterbox scale and padding;
    // e.g. for a 1280x720 source: gain = 0.5, padW = 0, padH = 140
    public static float[] xy2xy(int w1,int h1,float[] x1y1x2y2){

        float gain = Math.min((float) netWidth / w1, (float) netHeight / h1);
        float padW = (netWidth - w1 * gain) * 0.5f;
        float padH = (netHeight - h1 * gain) * 0.5f;
        float xmin = x1y1x2y2[0];
        float ymin = x1y1x2y2[1];
        float xmax = x1y1x2y2[2];
        float ymax = x1y1x2y2[3];
        float xmin_ = Math.max(0, Math.min(w1 - 1, (xmin - padW) / gain));
        float ymin_ = Math.max(0, Math.min(h1 - 1, (ymin - padH) / gain));
        float xmax_ = Math.max(0, Math.min(w1 - 1, (xmax - padW) / gain));
        float ymax_ = Math.max(0, Math.min(h1 - 1, (ymax - padH) / gain));
        return new float[]{xmin_,ymin_,xmax_,ymax_};
    }


    // draw a text label via AWT (OpenCV's putText cannot render Chinese glyphs)
    public static void putText(Mat src,String text,int x,int y,int charHeight,Color fontColor,Color backgroundColor){
        // out of bounds
        if(x<0||y<0){
            return;
        }
        // measure the rendered string width
        Font font = new Font("Dialog", Font.BOLD, charHeight); // font and size
        FontDesignMetrics metrics = FontDesignMetrics.getMetrics(font); // note: sun.font.* is a JDK-internal API
        int textWidth = metrics.stringWidth(text);
        // create a blank AWT image and draw the text into it
        BufferedImage image = new BufferedImage(textWidth, charHeight, BufferedImage.TYPE_3BYTE_BGR);
        Graphics2D g2d = image.createGraphics();
        g2d.setColor(backgroundColor); // fill with the background color
        g2d.fillRect(0, 0, textWidth, charHeight); // cover the whole image area
        g2d.setFont(font); // drawing font
        g2d.setColor(fontColor); // text color
        g2d.drawString(text, 0, Double.valueOf(charHeight*0.85).intValue()); // baseline at ~85% of the char height
        g2d.dispose(); // release drawing resources
        // convert to an OpenCV Mat
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        Mat mat = Mat.eye(image.getHeight(), image.getWidth(), CvType.CV_8UC3);
        mat.put(0, 0, pixels);
        // blit mat onto src at the given position
        int colStart = x;
        int colEnd = x + mat.width();
        int rowStart = y;
        int rowEnd = y + mat.height();
        // clamp to the image area
        if(x>src.width()||y>src.height()){
            System.out.println("label out of bounds");
            return;
        }
        if(colEnd>src.width()){
            colEnd = src.width();
        }
        if(rowEnd>src.height()){
            rowEnd = src.height();
        }
        // crop so the label does not run past the image edges
        int sub_x = 0;
        int sub_y = 0;
        int sub_w = colEnd - colStart - 1;
        int sub_h =  rowEnd - rowStart - 1;
        if(sub_w<=0||sub_h<=0){
            System.out.println("no drawable area");
            return;
        }
        // build the ROI rectangle to crop from the label image
        Rect roi = Imgproc.boundingRect(new MatOfPoint(
                new Point(sub_x,sub_y),
                new Point(sub_x,sub_y+sub_h),
                new Point(sub_x+sub_w,sub_y),
                new Point(sub_x+sub_w,sub_y+sub_h)
        ));
        // extract the sub-image and copy it into src
        Mat subImage = new Mat(mat, roi);
        subImage.copyTo(src.submat(rowStart,rowEnd,colStart,colEnd));
    }


    public static void main(String[] args) throws Exception{


        // video file, rtsp stream, etc.
        String video = new File("").getCanonicalPath() + "\\model\\deeplearning\\yolov5\\1.mp4";

        // create a VideoCapture and open the video
        VideoCapture cap = new VideoCapture(video);

        // target fps; interval is the per-frame time budget in ms (integer division)
        int fps = 30;
        int interval = 1000/fps;

        // frame width and height
        int width = (int) cap.get(Videoio.CAP_PROP_FRAME_WIDTH);
        int height = (int) cap.get(Videoio.CAP_PROP_FRAME_HEIGHT);

        // swing panel used for display
        JFrame win = new JFrame("Image");
        JPanel panel = new JPanel();
        panel.setPreferredSize(new Dimension(width, height));
        win.getContentPane().add(panel);
        win.setVisible(true);
        win.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        win.setResizable(true);
        win.pack();

        // display buffer; to modify the image, write into the pixels array directly
        BufferedImage buffer = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);
        byte[] pixels = ((DataBufferByte) buffer.getRaster().getDataBuffer()).getData();

        // a Mat reused for every decoded frame
        Mat frame = new Mat(height, width, CvType.CV_8UC3);
        int realFps = 0; // measured fps
        int frameIndex = 0; // current frame index
        double lastTime = 0; // time of the last fps measurement
        int sleepTime = 0;// sleep time before rendering
        long lastDraw = 0;// time of the last render
        long inferTime = 0;// average inference time per frame over the last fps frames
        long inferTimeTotal = 0;// total inference time accumulated over the last fps frames

        // process every frame
        while (cap.read(frame)) {

            // run inference and annotate this frame; returns the inference time in ms
            long use = infer(frame);

            inferTimeTotal += use;

            // copy the Mat into the pixel buffer; essentially free
            frame.get(0,0,pixels);

            // every fps frames, recompute the measured fps and the average per-frame inference time
            if(frameIndex%fps==0){
                double thisTime = System.currentTimeMillis();
                // measured fps
                realFps = (int)(1000/((thisTime - lastTime)/fps));
                // average inference time per frame; reset the accumulator
                inferTime = inferTimeTotal / fps;
                inferTimeTotal = 0;
                // remember the time of this measurement
                lastTime = thisTime;
            }

            // sleep so consecutive renders keep a steady interval; sleepTime is shown in the overlay
            sleepTime = 0;
            while(System.currentTimeMillis()-lastDraw<interval){
                try {
                    // sleep 1 ms at a time until the target interval since the last render has elapsed
                    Thread.sleep(1);
                    sleepTime++;
                } catch (InterruptedException e1) {
                    e1.printStackTrace();
                }
            }
            lastDraw = System.currentTimeMillis();

            // real-time rendering (essentially free); the top-left overlay shows fps, sleep and inference time
            Graphics2D g2 =(Graphics2D)buffer.getGraphics();
            g2.setColor(Color.BLACK);
            g2.drawString("FPS: "+realFps+"   "+"Sleep: "+sleepTime+"ms   "+"Infer: "+inferTime +"ms", 5, 15);
            panel.getGraphics().drawImage(buffer, 0, 0, panel);

            frameIndex++;

        }



    }


}


The flow: load the weights, feed in an image resized to the model's input shape, normalize by dividing by 255, lay the pixels out in CHW order, then run the model. Afterwards parse the 25200*85 output, filter boxes by the confidence and NMS thresholds, map the surviving boxes back to the original image coordinate system using the letterbox scale, and finally draw the annotations.
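For concreteness, the letterbox math in resizeWithPadding for a hypothetical 1280x720 frame and the 640x640 network input:

double r = Math.min(640.0 / 1280, 640.0 / 720); // 0.5: scale by the tighter side
int newW = (int) Math.round(1280 * r);          // 640
int newH = (int) Math.round(720 * r);           // 360
int dh = (640 - newH) / 2;                      // 140 px of padding top and bottom
// xy2xy later inverts this: original_y = (model_y - 140) / 0.5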


As shown, the model outputs 1*25200*85: each row of 85 floats stores the box confidence, the center coordinates and width/height, and the per-class probabilities, which you can parse yourself. The model is then loaded and run with Microsoft's com.microsoft.onnxruntime library.
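For reference, a minimal sketch of decoding a single 85-float row (the helper name decodeRow is illustrative). Note the listings here threshold on the box confidence and the class score separately, while YOLOv5's own detect.py multiplies obj_conf by cls_conf before thresholding:

// row layout: [cx, cy, w, h, objectness, class_0 ... class_79] for the 80 COCO classes
static void decodeRow(float[] row) {
    float cx = row[0], cy = row[1], w = row[2], h = row[3];
    float objectness = row[4];
    int bestClass = 0;
    float bestScore = 0f;
    for (int i = 5; i < row.length; i++) {
        if (row[i] > bestScore) { bestScore = row[i]; bestClass = i - 5; }
    }
    float confidence = objectness * bestScore; // detect.py: conf = obj_conf * cls_conf
    System.out.printf("class=%d conf=%.3f box=[%.1f, %.1f, %.1f, %.1f]%n",
            bestClass, confidence, cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2);
}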

Open Neural Network Exchange (ONNX) is an open format for representing deep neural network models, introduced by Microsoft and Facebook in 2017 and quickly supported by major vendors and frameworks. Within a few years it has become the de facto standard for representing deep learning models, and through ONNX-ML it also covers traditional non-neural-network machine learning models, moving toward a unified AI model exchange standard. ONNX defines a set of environment- and platform-independent standard formats that give AI models a basis for interoperability, so a model can be exchanged between different frameworks and environments. Hardware and software vendors can optimize model performance against the ONNX standard and benefit every ONNX-compatible framework. In short, ONNX is the middleman of model conversion.

Convert the model to onnx with the following script:
python export.py --weights C:\Users\tangyufan\Desktop\custom\res\custom-01\weights\best.pt --include torchscript onnx --opset 16
Note:
Use --include torchscript onnx to generate the onnx file, and specify --opset 16; the com.microsoft.onnxruntime Java library checks the opset version when loading, and which versions are accepted depends on the ort library version in use.
Each ONNX opset supports a different set of operators, so the opset_version value determines which PyTorch operators can be exported to ONNX. If the model uses an operator the chosen ONNX operator set does not support, the export will fail; likewise, the exported model can only run on platforms that support the corresponding ONNX version.
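A quick sanity check that the exported file is loadable by the onnxruntime version on the classpath is simply to open a session; an opset newer than the runtime supports surfaces as an OrtException. A minimal sketch (the model path is illustrative):

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class OnnxLoadCheck {
    public static void main(String[] args) {
        try (OrtSession s = OrtEnvironment.getEnvironment()
                .createSession("yolov5s.onnx", new OrtSession.SessionOptions())) {
            System.out.println("loaded ok, inputs: " + s.getInputNames());
        } catch (OrtException e) {
            // typically an unsupported opset or a malformed model file
            System.err.println("load failed: " + e.getMessage());
        }
    }
}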

The full code is below:

package tool.yolo.onnxruntime;
 
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
 
/**
*   @desc : load a yolov5 onnx model with com.microsoft.onnxruntime and run inference
*   @auth : tyf
*   @date : 2023-03-21  09:31:31
*/
public class predictTest {
    public static OrtEnvironment env;
    public static OrtSession session;
    public static JSONObject names;
    public static long count;
    public static long channels;
    public static long netHeight;
    public static long netWidth;
    public static float confThreshold = 0.25f;
    public static float nmsThreshold = 0.45f;
    static {
        String weight = "C:\\Users\\tyf\\Desktop\\yolov5s.onnx";
        try{
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(weight, new OrtSession.SessionOptions());
            OnnxModelMetadata metadata = session.getMetadata();
            Map<String, NodeInfo> infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            System.out.println("------- model info start --------");
            System.out.println("getProducerName="+metadata.getProducerName());
            System.out.println("getGraphName="+metadata.getGraphName());
            System.out.println("getDescription="+metadata.getDescription());
            System.out.println("getDomain="+metadata.getDomain());
            System.out.println("getVersion="+metadata.getVersion());
            System.out.println("getCustomMetadata="+metadata.getCustomMetadata());
            System.out.println("getInputInfo="+infoMap);
            System.out.println("nodeInfo="+nodeInfo);
            System.out.println("------- model info end --------");
            names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
            System.out.println("class names: "+names);
            count = nodeInfo.getShape()[0];//1 image per batch
            channels = nodeInfo.getShape()[1];//3 input channels
            netHeight = nodeInfo.getShape()[2];//640 network input height
            netWidth = nodeInfo.getShape()[3];//640 network input width
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
 
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
    }

    public static Mat readImg(String path){
        Mat img = Imgcodecs.imread(path);
        return img;
    }
 

    public static Mat resizeWithPadding(Mat src) {
        Mat dst = new Mat();
        int oldW = src.width();
        int oldH = src.height();
        double r = Math.min((double) netWidth / oldW, (double) netHeight / oldH);
        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);
        int dw = ((int) netWidth - newUnpadW) / 2;
        int dh = ((int) netHeight - newUnpadH) / 2;
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);
        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
        return dst;
    }

    public static OnnxTensor transferTensor(Mat dst){
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
        float[] whc = new float[(int) (channels * netWidth * netHeight)];
        dst.get(0, 0, whc);
        float[] chw = whc2cwh(whc);
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count, channels, netHeight, netWidth});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
        return tensor;
    }
 
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }
 
    public static int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxVal = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxVal) {
                maxVal = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }
 
 
    public static float[] xywh2xyxy(float[] bbox) {
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > netWidth ? netWidth:x2,
                y2 > netHeight? netHeight:y2};
    }
 
    public static JSONArray filterRec1(float[][] data){
        JSONArray recList = new JSONArray();
        for (float[] bbox : data){
            float[] xywh = new float[] {bbox[0],bbox[1],bbox[2],bbox[3]};
            float[] xyxy = xywh2xyxy(xywh);
            float confidence = bbox[4];
            float[] classInfo = Arrays.copyOfRange(bbox, 5, 85);
            int maxIndex = getMaxIndex(classInfo);
            float maxValue = classInfo[maxIndex];
            String maxClass = (String)names.get(Integer.valueOf(maxIndex));
            // coarse filtering by box confidence first
            if(confidence>=confThreshold){
                JSONObject detect = new JSONObject();
                detect.put("name",maxClass);// class
                detect.put("percentage",maxValue);// score
                detect.put("xmin",xyxy[0]);
                detect.put("ymin",xyxy[1]);
                detect.put("xmax",xyxy[2]);
                detect.put("ymax",xyxy[3]);
                recList.add(detect);
            }
        }
        return recList;
    }
 
    public static JSONArray filterRec2(JSONArray data){
        JSONArray res = new JSONArray();
        data.sort(Comparator.comparing(obj -> ((JSONObject) obj).getFloat("percentage")).reversed());
        while (!data.isEmpty()){
            JSONObject max = data.getJSONObject(0);
            res.add(max);
            Iterator<Object> it = data.iterator();
            while (it.hasNext()) {
                JSONObject obj = (JSONObject)it.next();
                double iou = calculateIoU(max, obj);
                if (iou > nmsThreshold) {
                    it.remove();
                }
            }
        }
        return res;
    }
 
    private static double calculateIoU(JSONObject box1, JSONObject box2) {
        double x1 = Math.max(box1.getDouble("xmin"), box2.getDouble("xmin"));
        double y1 = Math.max(box1.getDouble("ymin"), box2.getDouble("ymin"));
        double x2 = Math.min(box1.getDouble("xmax"), box2.getDouble("xmax"));
        double y2 = Math.min(box1.getDouble("ymax"), box2.getDouble("ymax"));
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1.getDouble("xmax") - box1.getDouble("xmin") + 1) * (box1.getDouble("ymax") - box1.getDouble("ymin") + 1);
        double box2Area = (box2.getDouble("xmax") - box2.getDouble("xmin") + 1) * (box2.getDouble("ymax") - box2.getDouble("ymin") + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }
 
    public static JSONArray transferSrc2Dst(JSONArray data,int srcw,int srch){
        JSONArray res = new JSONArray();
        float gain = Math.min((float) netWidth / srcw, (float) netHeight / srch);
        float padW = (netWidth - srcw * gain) * 0.5f;
        float padH = (netHeight - srch * gain) * 0.5f;
        data.stream().forEach(n->{
            JSONObject obj = JSONObject.parseObject(n.toString());
            float xmin = obj.getFloat("xmin");
            float ymin = obj.getFloat("ymin");
            float xmax = obj.getFloat("xmax");
            float ymax = obj.getFloat("ymax");
            float xmin_ = Math.max(0, Math.min(srcw - 1, (xmin - padW) / gain));
            float ymin_ = Math.max(0, Math.min(srch - 1, (ymin - padH) / gain));
            float xmax_ = Math.max(0, Math.min(srcw - 1, (xmax - padW) / gain));
            float ymax_ = Math.max(0, Math.min(srch - 1, (ymax - padH) / gain));
            obj.put("xmin",xmin_);
            obj.put("ymin",ymin_);
            obj.put("xmax",xmax_);
            obj.put("ymax",ymax_);
            res.add(obj);
        });
        return res;
    }
    public static void pointBox(String pic,JSONArray box){
        if(box.size()==0){
            System.out.println("no targets detected");
            return;
        }
        try {        
            File imageFile = new File(pic);
            BufferedImage img = ImageIO.read(imageFile);
            Graphics2D graph = img.createGraphics();
            graph.setStroke(new BasicStroke(2));
            graph.setFont(new Font("Serif", Font.BOLD, 20));
            graph.setColor(Color.RED);
            box.stream().forEach(n->{
                JSONObject obj = JSONObject.parseObject(n.toString());
                String name = obj.getString("name");
                float percentage = obj.getFloat("percentage");
                float xmin = obj.getFloat("xmin");
                float ymin = obj.getFloat("ymin");
                float xmax = obj.getFloat("xmax");
                float ymax = obj.getFloat("ymax");
                float w = xmax - xmin;
                float h = ymax - ymin;
                graph.drawRect(
                        Float.valueOf(xmin).intValue(), 
                        Float.valueOf(ymin).intValue(),
                        Float.valueOf(w).intValue(),
                        Float.valueOf(h).intValue());
                DecimalFormat decimalFormat = new DecimalFormat("#.##");
                String percentString = decimalFormat.format(percentage);
                graph.drawString(name+" "+percentString, xmin-1, ymin-5);
            });
            graph.dispose();
            JFrame frame = new JFrame("Image Dialog");
            frame.setSize(img.getWidth(), img.getHeight());
            JLabel label = new JLabel(new ImageIcon(img));
            frame.getContentPane().add(label);
            frame.setVisible(true);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
    }
    public static void main(String[] args) throws Exception{
        String pic = "C:\\Users\\tyf\\Desktop\\img.png";
        Mat src = readImg(pic);
        int srcw = src.width();
        int srch = src.height();
        Mat dst = resizeWithPadding(src);
        OnnxTensor tensor = transferTensor(dst);
        OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
        OnnxTensor res = (OnnxTensor)result.get(0);
        float[][][] dataRes = (float[][][])res.getValue();
        float[][] data = dataRes[0];
        JSONArray srcRec = filterRec1(data);
        JSONArray srcRec2 = filterRec2(srcRec);
        JSONArray dstRec = transferSrc2Dst(srcRec2,srcw,srch);
        pointBox(pic,dstRec);
    }
}

There are actually two artifacts: the first (onnxruntime) is CPU-only, while the second (onnxruntime_gpu) can run on CPU or GPU. Depend on exactly one of them:

<!-- CPU-only -->
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.11.0</version>
</dependency>

<!-- CPU or GPU -->
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime_gpu</artifactId>
    <version>1.11.0</version>
</dependency>

Enable GPU execution as follows:

    int gpuDeviceId = 0; // the GPU device ID to execute on
    var sessionOptions = new OrtSession.SessionOptions();
    sessionOptions.addCUDA(gpuDeviceId);
    var session = environment.createSession("model.onnx", sessionOptions);

The deviceId is found by querying the CUDA devices; here it is 0.
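If the CUDA provider may be missing at runtime, a hedged pattern is to attempt addCUDA and fall back to CPU (a sketch; createSession still throws OrtException):

OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
try {
    opts.addCUDA(0); // device id 0, as queried above
} catch (OrtException e) {
    System.err.println("CUDA provider unavailable, falling back to CPU: " + e.getMessage());
}
OrtSession session = env.createSession("model.onnx", opts);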
