DJL调用目标检测模型检测rtsp视频流

1.获取目标检测模型

/**
 * Builds the DJL {@link Criteria} used to load the YOLOv5 TorchScript object-detection model.
 *
 * <p>Originally authored by bjiang (2022/1/14).
 */
public class ModelCriteria {

    /** Default directory that holds the exported model and its {@code synset.txt}. */
    private static final String DEFAULT_MODEL_DIR = "D:\\work\\git\\model\\pt";

    /** Default TorchScript model file inside {@link #DEFAULT_MODEL_DIR}. */
    private static final String DEFAULT_MODEL_NAME = "best.torchscript.pt";

    /** Square input size (pixels) passed to the translator as both width and height. */
    private static final int INPUT_SIZE = 640;

    /**
     * Returns criteria for the model at the default location.
     *
     * @return criteria that load {@code best.torchscript.pt} from the default model directory
     */
    public Criteria<Image, DetectedObjects> getCriteria() {
        return getCriteria(DEFAULT_MODEL_DIR, DEFAULT_MODEL_NAME);
    }

    /**
     * Returns criteria for a YOLOv5 TorchScript model at a caller-chosen location.
     *
     * @param modelDir  directory holding the model file and {@code synset.txt}
     * @param modelName model file name inside {@code modelDir}
     * @return criteria for a PyTorch, CPU-only {@code Image -> DetectedObjects} model
     */
    public Criteria<Image, DetectedObjects> getCriteria(String modelDir, String modelName) {
        Map<String, Object> arguments = new ConcurrentHashMap<>();
        arguments.put("width", INPUT_SIZE);
        arguments.put("height", INPUT_SIZE);
        arguments.put("resize", true);  // resize the input image to width x height
        arguments.put("rescale", true); // scale pixel values into the 0-1 range
        Translator<Image, DetectedObjects> translator =
                YoloV5Translator.builder(arguments).optSynsetArtifactName("synset.txt").build();
        return Criteria.builder()
                // Boxes + class labels: this is object detection, not instance segmentation.
                .optApplication(Application.CV.OBJECT_DETECTION)
                .setTypes(Image.class, DetectedObjects.class)
                .optDevice(Device.cpu())
                .optModelPath(Paths.get(modelDir))
                .optModelName(modelName)
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .optEngine("PyTorch")
                .build();
    }
}

2.获取rtsp视频流

 /**
     * Opens an RTSP stream with FFmpeg (TCP transport) and starts grabbing from it.
     *
     * <p>Originally authored by bjiang (2022/1/14).
     *
     * @param url RTSP URL of the stream
     * @return a started {@link FFmpegFrameGrabber}; the caller must stop/release it
     * @throws IllegalStateException if the stream cannot be opened; the cause is the FFmpeg error
     */
    public static FFmpegFrameGrabber getRtspByUrl(String url) {
        FFmpegFrameGrabber grabber = new FFmpegFrameGrabber(url);
        grabber.setFormat("rtsp");

        // NOTE(review): RSAUtil.getMsgByRsa(String) encrypts with a key pair it generates on every
        // call and then discards, so nobody can decrypt this value — confirm what is expected here.
        String secret = null;
        try {
            secret = RSAUtil.getMsgByRsa("****");
        } catch (Exception e) {
            // Best effort: the stream is still opened without the secret option.
            e.printStackTrace();
        }

        grabber.setOption("rtsp_transport", "tcp"); // carry RTSP over TCP
        // NOTE(review): appkey/secret/playMode/port/enableHTTPS look like Hikvision platform API
        // parameters rather than FFmpeg RTSP options; FFmpeg presumably ignores them — confirm.
        grabber.setOption("appkey", "****");        // Hikvision appkey
        if (secret != null) {
            grabber.setOption("secret", secret);    // Hikvision secret
        }
        grabber.setOption("playMode", "0");         // initial play mode: 0 = preview, 1 = playback
        grabber.setOption("port", "446");           // platform port (443 by default with HTTPS)
        grabber.setOption("enableHTTPS", "1");      // talk to the platform over HTTPS; always 1
        grabber.setOption("rtsp_flags", "prefer_tcp");
        grabber.setOption("stimeout", "3000000");   // socket timeout; microseconds in FFmpeg 4.x

        try {
            grabber.start();
        } catch (FFmpegFrameGrabber.Exception e) {
            // Do not hand back a half-initialised grabber: free its native state and fail loudly.
            try {
                grabber.release();
            } catch (Exception releaseError) {
                e.addSuppressed(releaseError);
            }
            // The URL is left out of the message because RTSP URLs often embed credentials.
            throw new IllegalStateException("Could not open RTSP stream", e);
        }
        return grabber;
    }

3.增加窗口,实时显示视频流

/**
     * Shows the stream in a window and draws object-detection results on each frame until the
     * window is closed or the grabber returns no more frames.
     *
     * <p>Originally authored by bjiang (2022/1/14).
     *
     * @param grabber a started grabber; this method does not stop or release it
     * @throws Exception if loading the model, grabbing a frame, or inference fails
     */
    public static void showCanvasFrame(FFmpegFrameGrabber grabber) throws Exception {
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        // Live display window for the detection results.
        CanvasFrame canvas = new CanvasFrame("目标监测");
        // DISPOSE_ON_CLOSE, not EXIT_ON_CLOSE: this is invoked from a web controller, and closing
        // the window must only end this loop, not System.exit() the whole server JVM.
        canvas.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
        canvas.setVisible(true);
        canvas.setFocusable(true);
        if (canvas.isAlwaysOnTopSupported()) {
            canvas.setAlwaysOnTop(true);
        }
        Criteria<Image, DetectedObjects> criteria = new ModelCriteria().getCriteria();
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
             Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            Frame frame;
            // Closing the window makes isVisible() false, which ends the loop.
            while (canvas.isVisible() && (frame = grabber.grabImage()) != null) {
                Mat img = converter.convert(frame);
                // Presumably ToMat wraps the frame's pixel buffer without copying (confirm), which
                // is why boxes drawn on img show up when the frame itself is displayed.
                drawObject(predictor, img);
                canvas.showImage(frame);
            }
        } finally {
            canvas.dispose();
        }
    }

4.在视频窗口绘制目标检测结果

/**
     * Runs the detector on one frame and draws every detection's box and label into it in place.
     *
     * <p>Originally authored by bjiang (2022/1/17).
     *
     * @param predictor DJL predictor for the object-detection model
     * @param img       frame to analyse; boxes and labels are drawn directly into this Mat
     * @throws Exception if inference fails
     */
    public static void drawObject(Predictor<Image, DetectedObjects> predictor, Mat img) throws Exception {
        BufferedImage buffImg = mat2BufferedImage(img);
        if (buffImg == null) {
            // mat2BufferedImage returns null for Mats it cannot convert; skip this frame instead of
            // passing null into ImageFactory.
            return;
        }
        Image image = ImageFactory.getInstance().fromImage(buffImg);
        int imageWidth = image.getWidth();
        int imageHeight = image.getHeight();
        // Logged once per frame, so keep it at debug level to avoid flooding the log.
        log.debug("imageWidth={},imageHeight={}", imageWidth, imageHeight);

        DetectedObjects detections = predictor.predict(image);
        List<DetectedObjects.DetectedObject> items = detections.items();
        for (DetectedObjects.DetectedObject item : items) {
            // Bounds are scaled by the image size, i.e. treated as fractions of width/height.
            Rectangle bounds = item.getBoundingBox().getBounds();
            Rect box = new Rect(
                    (int) (bounds.getX() * imageWidth),
                    (int) (bounds.getY() * imageHeight),
                    (int) (bounds.getWidth() * imageWidth),
                    (int) (bounds.getHeight() * imageHeight));
            // Draw the box in red (BGR order).
            rectangle(img, box, new Scalar(0, 0, 255, 1));

            // Place the label 10 px above/left of the box's top-left corner, clamped at 0.
            int labelX = Math.max(box.tl().x() - 10, 0);
            int labelY = Math.max(box.tl().y() - 10, 0);
            String label = String.format(java.util.Locale.ROOT, "%s %.2f",
                    item.getClassName(), item.getProbability());
            // Draw the class name and probability.
            putText(
                    img,
                    label,
                    new Point(labelX, labelY),
                    FONT_HERSHEY_COMPLEX,
                    1.0,
                    new Scalar(0, 0, 255, 2.0));
        }
    }

附:工具类

  /**
     * Generates a new 2048-bit RSA key pair.
     *
     * @return a freshly generated RSA key pair
     * @throws Exception if the RSA algorithm is not available
     */
    public static KeyPair getKeyPair() throws Exception {
        KeyPairGenerator rsaGenerator = KeyPairGenerator.getInstance("RSA");
        rsaGenerator.initialize(2048);
        return rsaGenerator.generateKeyPair();
    }

    /**
     * Exports the public half of a key pair as Base64 text.
     *
     * @param keyPair key pair whose public key is exported
     * @return Base64 text of the public key's encoded form
     */
    public static String getPublicKey(KeyPair keyPair){
        byte[] encodedKey = keyPair.getPublic().getEncoded();
        return byte2Base64(encodedKey);
    }
    

    /**
     * Rebuilds an RSA {@link PublicKey} from its Base64-encoded X.509 form.
     *
     * @param pubStr Base64 text of an X.509-encoded RSA public key
     * @return the decoded public key
     * @throws Exception if the text cannot be decoded or is not a valid RSA public key
     */
    public static PublicKey string2PublicKey(String pubStr) throws Exception{
        X509EncodedKeySpec x509Spec = new X509EncodedKeySpec(base642Byte(pubStr));
        return KeyFactory.getInstance("RSA").generatePublic(x509Spec);
    }


    /**
     * Encrypts {@code content} with an RSA public key using PKCS#1 v1.5 padding.
     *
     * <p>The transformation is spelled out because the default for a bare {@code "RSA"} is
     * provider-dependent.
     *
     * @param content   plaintext bytes; must fit in one RSA block (key size in bytes minus 11)
     * @param publicKey RSA public key to encrypt with
     * @return the ciphertext, one RSA block long
     * @throws Exception if the cipher is unavailable, the key is invalid, or the input is too long
     */
    public static byte[] publicEncrypt(byte[] content, PublicKey publicKey) throws Exception{
        Cipher cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding");
        cipher.init(Cipher.ENCRYPT_MODE, publicKey);
        return cipher.doFinal(content);
    }
    

    /**
     * Encodes bytes as standard Base64 text on a single line (no line breaks).
     *
     * <p>Uses {@link java.util.Base64}; the former {@code sun.misc.BASE64Encoder} is not available
     * on JDK 9+ and wrapped long output onto several lines.
     *
     * @param bytes data to encode
     * @return Base64 text, with padding and without line separators
     */
    public static String byte2Base64(byte[] bytes){
        return java.util.Base64.getEncoder().encodeToString(bytes);
    }

    /**
     * Decodes Base64 text back into bytes.
     *
     * <p>Uses the lenient MIME decoder of {@link java.util.Base64}, which ignores line breaks and
     * other non-alphabet characters. Line-wrapped text produced by the old
     * {@code sun.misc.BASE64Encoder} therefore still decodes.
     *
     * @param base64Key Base64 text to decode
     * @return the decoded bytes
     * @throws IOException declared for source compatibility with existing callers
     * @throws IllegalArgumentException if the text is not valid Base64
     */
    public static byte[] base642Byte(String base64Key) throws IOException{
        return java.util.Base64.getMimeDecoder().decode(base64Key);
    }
    /**
     * Encrypts {@code message} with a freshly generated RSA public key and returns Base64 text.
     *
     * <p>NOTE(review): the key pair is generated inside this call and its private key is thrown
     * away, so the returned ciphertext cannot be decrypted by anyone. To encrypt for a real
     * recipient, use {@link #getMsgByRsa(String, PublicKey)} with that recipient's public key.
     *
     * <p>Originally authored by bjiang (2022/1/14).
     *
     * @param message text to encrypt (UTF-8 encoded)
     * @return Base64 text of the RSA ciphertext
     * @throws Exception if key generation or encryption fails
     */
    public static String getMsgByRsa(String message) throws Exception {
        KeyPair keyPair = RSAUtil.getKeyPair();
        return getMsgByRsa(message, keyPair.getPublic());
    }

    /**
     * Encrypts {@code message} with the given RSA public key and returns Base64 text.
     *
     * @param message   text to encrypt; its UTF-8 bytes must fit in one RSA block
     * @param publicKey recipient's RSA public key
     * @return Base64 text of the RSA ciphertext
     * @throws Exception if encryption fails
     */
    public static String getMsgByRsa(String message, PublicKey publicKey) throws Exception {
        // Explicit UTF-8 so the ciphertext does not depend on the platform's default charset.
        byte[] plain = message.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        byte[] cipherText = RSAUtil.publicEncrypt(plain, publicKey);
        return RSAUtil.byte2Base64(cipherText);
    }
/**
     * Copies an 8-bit OpenCV {@link Mat} into a new {@link BufferedImage}.
     *
     * <p>1-channel Mats become {@code TYPE_BYTE_GRAY}. 3-channel Mats become
     * {@code TYPE_3BYTE_BGR}, whose raster stores each pixel as B, G, R bytes. That is the same
     * interleaved order as an OpenCV BGR Mat, so the bytes are copied without reordering.
     *
     * @param matrix source image; not modified
     * @return the converted image, or {@code null} if the Mat is empty, not 8 bits per channel,
     *     or has a channel count other than 1 or 3
     */
    public static BufferedImage mat2BufferedImage(Mat matrix) {
        if (matrix.empty() || matrix.elemSize1() != 1) {
            return null; // nothing to convert, or not one byte per channel
        }
        int channels = matrix.channels();
        int type;
        switch (channels) {
            case 1:
                type = BufferedImage.TYPE_BYTE_GRAY;
                break;
            case 3:
                // Do NOT swap to RGB: TYPE_3BYTE_BGR already expects B,G,R byte order.
                type = BufferedImage.TYPE_3BYTE_BGR;
                break;
            default:
                return null;
        }
        // ROI/sub-matrix views are not stored contiguously; clone to get one packed buffer.
        Mat packed = matrix.isContinuous() ? matrix : matrix.clone();
        int cols = packed.cols();
        int rows = packed.rows();
        byte[] data = new byte[cols * rows * channels];
        packed.data().get(data);

        BufferedImage image = new BufferedImage(cols, rows, type);
        image.getRaster().setDataElements(0, 0, cols, rows, data);
        return image;
    }

5.运行结果

    /**
     * Opens the RTSP stream given in {@code input} and shows it with live object detection in a
     * desktop window. Returns once the window is closed or the stream ends.
     *
     * <p>NOTE(review): this blocks the request thread for the whole session and opens a Swing
     * window on the server host — presumably intended only for local demos.
     *
     * @param input RTSP URL; must start with {@code rtsp://} or {@code rtsps://}
     * @throws IllegalArgumentException if {@code input} is missing or not an RTSP URL
     * @throws IllegalStateException    if detection fails; the cause holds the original error
     */
    @RequestMapping("/detectionRtsp")
    public void detectionRtsp(String input){
        // The URL comes straight from the HTTP request and is handed to FFmpeg: accept only RTSP.
        boolean isRtsp = input != null
                && (input.regionMatches(true, 0, "rtsp://", 0, 7)
                        || input.regionMatches(true, 0, "rtsps://", 0, 8));
        if (!isRtsp) {
            throw new IllegalArgumentException("input must be an rtsp:// or rtsps:// URL");
        }
        FFmpegFrameGrabber grabber = RtspUtils.getRtspByUrl(input);
        try {
            RtspUtils.showCanvasFrame(grabber);
        } catch (Exception e) {
            throw new IllegalStateException("RTSP detection failed", e);
        } finally {
            try {
                grabber.release(); // free FFmpeg native resources held by the grabber
            } catch (Exception ignored) {
                // Best-effort cleanup; there is nothing more useful to do if release fails.
            }
        }
    }

 <!--  javacv-->
        <!-- NOTE(review): javacv-platform already pulls in javacv and ffmpeg-platform; the first two
             entries are presumably redundant and only pin versions explicitly. -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv</artifactId>
            <version>1.5.6</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ffmpeg-platform</artifactId>
            <version>4.4-1.5.6</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
            <version>1.5.6</version>
        </dependency>

        <!-- Jars required by DJL -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <!-- NOTE(review): the native PyTorch version must be one supported by ${djl.version};
             1.9.1 presumably targets DJL 0.14/0.15 — confirm when upgrading DJL. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <scope>runtime</scope>
            <version>1.9.1</version>
        </dependency>

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 18
    评论
调用 ONNX 模型可以使用 DJL 的 ONNX 模块。具体步骤如下(注意:以下代码仅为示意,其中部分类,如 `TranslatorFactoryImpl`、`OnnxModelZoo`、`ai.djl.translate.batch.*`,并不是 DJL 的实际 API,使用前请对照 DJL 官方文档):

1. 导入相关依赖

```java
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.translate.TranslatorFactory.TranslatorFunction;
import ai.djl.translate.TranslatorFactoryImpl;
import ai.djl.translate.TranslatorUtils;
import ai.djl.translate.batch.Batchifier;
import ai.djl.translate.batch.DefaultBatchifier;
import ai.djl.util.Utils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
```

2. 加载 ONNX 模型

```java
Path modelDir = Path.of("path/to/model/dir");
Model model = Model.newInstance("MyModel");
model.setBlock(OnnxModelZoo.resnet50().getBlock());
model.load(modelDir, "model.onnx");
```

3. 创建 Translator

```java
public class MyTranslator implements Translator<MyInput, MyOutput> {
    private List<String> classes;

    public MyTranslator(List<String> classes) {
        this.classes = classes;
    }

    @Override
    public Batchifier getBatchifier() {
        return DefaultBatchifier.INSTANCE;
    }

    @Override
    public MyOutput processOutput(TranslatorContext ctx, Output output) {
        NDArray array = output.getNDArray();
        return new MyOutput(classes.get(array.argMax().getFloat()));
    }

    @Override
    public MyInput processInput(TranslatorContext ctx, Input input) {
        NDArray array = input.getNDArray();
        return new MyInput(array);
    }
}

TranslatorFactory factory = new TranslatorFactoryImpl();
factory.registerTranslator(MyInput.class, MyOutput.class, new TranslatorFunction<MyInput, MyOutput>() {
    @Override
    public Translator<MyInput, MyOutput> apply(TranslatorContext ctx) {
        List<String> classes = null;
        try {
            classes = Files.readAllLines(Path.of("path/to/classes.txt"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new MyTranslator(classes);
    }
});
```

4. 推理

```java
MyInput input = new MyInput(inputArray);
MyOutput output = model.predict(input).get(0).getOutput(MyOutput.class);
```

其中,`MyInput` 和 `MyOutput` 分别表示输入和输出的数据类型,需要根据实际情况进行定义。`classes.txt` 文件包含了模型输出的类别信息。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值