DeepJavaLibrary(DJL)框架的使用:在java上使用AlphaPose完成实时多人姿态估计任务

首先,这里是完整的项目代码

实现环境

  • java 1.8
  • djl 1.12.0
  • opencv-java 4.5.1(maven安装了javacv,暂时可以不用考虑手动配置)

实现原理

DJL框架能帮我们做什么

DJL框架为我们提供了在java上实现多种推理引擎的适配,因此我们在导出libtorch、onnx、mxnet等格式的模型后可以很容易的在java上完成推理任务。
DJL不能为我们完成比较复杂的后期计算任务,因此我们可能需要以来opencv框架或者DJL内置的NDArray对象来完成推理以外的计算。

Alphapose

Alphapose的推理模型主要由两部分组成

  1. yolo目标检测器
  2. alphapose自己的单人姿态估计网络(SPPE)
    两个模型中间,alphapose使用了仿射变换来对yolo的检测结果进行scale,并用对应的逆操作对SPPE的输出进行坐标系还原。

实现步骤

一、导出alphapose模型

导出yolov5

导出yolov5的方法网上很多,可以参考我的博客

导出单人姿态估计网络

可以参考这篇博客

二、实现Translator

使用YoloTranslator

可以参考我的博客(还是刚刚那个😅)

实现SPPETranslator

Translator是DJL为我们提供的一个推理模板,我们可以重写模板内的方法,完成推理以外数据预处理和后续处理工作

1.我们定义SPPETranslator
public class SPPETranslator extends BasePairTranslator<Mat, Rectangle, Joints> {
...
}

BasePairTranslator<I,P,O>是我自己封装的一个类,I代表输入类型,O代表输出类型,P是由于我们的输入包含两个,分别是原图和yolo输出的边界框。

2.数据预处理

这里用ctx.getAttachmen来保存仿射变换得到的一个新的尺寸,用于之后还原响应的坐标。
之所以cropped_bboxes是使用队列是为了能够适用动态batch的推理。

    @Override
    public NDList processInput(TranslatorContext ctx, Pair<Mat, Rectangle> input) throws Exception {
        Mat frame = input.getKey().clone();
        Rectangle bbox = input.getValue();
        int x = (int) Math.max(0, bbox.getX());
        int y = (int) Math.max(0, bbox.getY());
        int w = Math.min(frame.width(), (int) (x + bbox.getWidth())) - x;
        int h = Math.min(frame.height(), (int) (y + bbox.getHeight())) - y;
        Rectangle croppedBBox = CVUtils.scale(frame, x, y, w, h);


        Queue cropped_bboxes = (Queue) ctx.getAttachment("cropped_bboxes");
        if (cropped_bboxes == null) {
            cropped_bboxes = new LinkedList<>();
            ctx.setAttachment("cropped_bboxes", cropped_bboxes);
        }
        cropped_bboxes.add(croppedBBox);

        NDArray array = ImageUtils.mat2Image(frame).toNDArray(ctx.getNDManager(), Image.Flag.COLOR);

        return pipeline.transform(new NDList(array));
    }

另外,这一步十分关键的是要在java上调用opencv,实现图像的一个仿射变换。

    /**
     * Convert box coordinates to center and scale.
     * adapted from https://github.com/Microsoft/human-pose-estimation.pytorch
     *
     * @param mat
     * @param x
     * @param y
     * @param w
     * @param h
     * @return cropped box
     */
    public static Rectangle scale(Mat mat,
                                  double x, double y, double w, double h, double inputH, double inputW) {
        double inpCenterX = inputW / 2, inpCenterY = inputH / 2;
        double aspectRatio = inputW / inputH;

        double scaleMult = 1.25;
        // box_to_center_scale
        double centerX = x + 0.5 * w;
        double centerY = y + 0.5 * h;

        if (w > aspectRatio * h)
            h = w / aspectRatio;
        else if (w < aspectRatio * h)
            w = h * aspectRatio;
        double scaleX = w * scaleMult;
        double scaleY = h * scaleMult;

//        double rot = 0;
//        double sn = Math.sin(rot), cs = Math.cos(rot);
        // 获取仿射矩阵
        Mat trans = getAffineTransform(centerX, centerY, inputW, inputH, scaleX, false);
        // 仿射变化
        Imgproc.warpAffine(mat, mat, trans, new Size(inputW, inputH), Imgproc.INTER_LINEAR);

//        HighGui.imshow("person", mat);
        return new Rectangle(centerX - scaleX * 0.5, centerY - scaleY * 0.5, scaleX, scaleY);
    }

变换的效果大概像下面这样
在这里插入图片描述

3. 推理后的数据处理

这部分实际就是仿照alphapose源码进行重写,向量计算可以使用DJL框架内封装的NDArray,即使是在cpu上也会比使用for循环处理快。

    @Override
    public Joints processOutput(TranslatorContext ctx, NDList list) {

        NDArray pred = list.singletonOrThrow().toDevice(Device.cpu(), false);
        int numJoints = (int) pred.getShape().get(0);
        int height = (int) pred.getShape().get(1);
        int width = (int) pred.getShape().get(2);
        pred = Activation.sigmoid(pred.reshape(new Shape(1, numJoints, -1)));
        NDArray maxValues = pred.max(axis2, true).toType(DataType.FLOAT32, false);
        //normalized to probability
        NDArray heatmaps = pred
                .div(pred.sum(axis2, true))
                .reshape(1, numJoints, 1, height, width);

        // The edge probability
        NDArray hmX = heatmaps.sum(axis2n3);
        NDArray hmY = heatmaps.sum(axis2n4);
//        NDArray hmZ = heatmaps.sum(axis3n4);

        NDManager ndManager = NumpyUtils.ndManager;

        hmX = integralOp(hmX, ndManager);
        hmY = integralOp(hmY, ndManager);
//        hmZ = integralOp(hmZ, ndManager);

        NDArray coordX = hmX.sum(axis2, true);
        NDArray coordY = hmY.sum(axis2, true);

        NDArray predJoints = coordX
                .concat(coordY, 2)
                .reshape(1, numJoints, 2)
                .toType(DataType.FLOAT32, false);

        Rectangle bbox = (Rectangle) ((Queue) ctx.getAttachment("cropped_bboxes")).poll();
        double x = bbox.getX();
        double y = bbox.getY();
        double w = bbox.getWidth();
        double h = bbox.getHeight();
        double centerX = x + 0.5 * w, centerY = y + 0.5 * h;
        double scaleX = w;

        Mat trans = CVUtils.getAffineTransform(centerX, centerY, width, height, scaleX, true);
        NDArray ndTrans = CVUtils.transMat2NDArray(trans, ndManager).transpose(1, 0);

        predJoints = predJoints
                .concat(ONES_NDARRAY, 2);

        NDArray xys = predJoints.matMul(ndTrans);


        float[] flattened = xys.toFloatArray();
        float[] flattenedConfidence = maxValues.toFloatArray();

        List<Joint> joints = new ArrayList<>(numJoints);
        for (int i = 0; i < numJoints; ++i) {
            joints.add(new Joint(
                    flattened[2 * i],
                    flattened[2 * i + 1],
                    flattenedConfidence[i]));
        }
//        System.out.println(joints);
        return new Joints(joints);
    }

integralOp

    private static NDArray integralOp(NDArray hm, NDManager ndManager) {
        Shape hmShape = hm.getShape();
        NDArray arr = ndManager
                .arange(hmShape.get(hmShape
                        .dimension() - 1)).toType(DataType.FLOAT32, false);
        return hm.mul(arr);
    }

三、组合模型

这一步实际上就是输入图像到yolo模型,然后把yolo模型的输出输入到sppe,中间做一下简单的格式转换。

    static void detect(Mat frame,
                       YoloV5Detector detector,
                       ParallelPoseEstimator ppe) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        Image img = ImageUtils.mat2Image(frame);
        long startTime = System.currentTimeMillis();
        try {
            DetectedObjects results = detector.detect(img);
            List<DetectedObject> detectedObjects = new ArrayList<>(results.getNumberOfObjects());
            List<Rectangle> jointsInput = new ArrayList<>(results.getNumberOfObjects());
            for (DetectedObject obj : results.<DetectedObject>items()) {
                if ("person".equals(obj.getClassName())) {
                    detectedObjects.add(obj);
                    jointsInput.add(obj.getBoundingBox().getBounds());
                }
            }
            List<Joints> joints = ppe.infer(frame, jointsInput);
            for (DetectedObject obj : detectedObjects) {
                BoundingBox bbox = obj.getBoundingBox();
                Rectangle rectangle = bbox.getBounds();
                String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
                rect.x = (int) rectangle.getX();
                rect.y = (int) rectangle.getY();
                rect.width = (int) rectangle.getWidth();
                rect.height = (int) rectangle.getHeight();
                // 画框
                Imgproc.rectangle(frame, rect, color, 2);
                //画名字
                Imgproc.putText(frame, showText,
                        new Point(rect.x, rect.y),
                        Imgproc.FONT_HERSHEY_COMPLEX,
                        rectangle.getWidth() / 200,
                        color);
            }
            for (Joints jointsItem : joints) {
                CVUtils.draw136KeypointsLight(frame, jointsItem);
            }
        } finally {

        }
        boolean showFPS = true;
        if (showFPS) {
            double fps = 1000.0 / (System.currentTimeMillis() - startTime);
            System.out.println(String.format("%.2f", fps));
            Imgproc.putText(frame, String.format("FPS: %.2f", fps),
                    new Point(0, 52),
                    Imgproc.FONT_HERSHEY_COMPLEX,
                    0.5,
                    ColorConst.COLOR_RED);
        }

    }

实现结果

效果演示

因为对sppe模型做了一点轻量化,所以有些点不是特别准,不过也够用了。
另外这边实际上也像alphapose一样做了一个简单的流水线处理,这部分代码可以在开头提到的项目中找到。
在这里插入图片描述

性能说明

这边做了一下性能测试,纯推理速度框架上基本和python上差不多,中间数据的处理尽量少使用for循环,使用NDArray或opencv的原生方法,否则性能可能会不如python。
上面的SPPE部分经过一些优化目前能比python上快1.6倍

可以使用以下代码来通过djl使用yolov5: ``` import ai.djl.Model; import ai.djl.basicmodelzoo.basiccv.YoloV5; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Utils; import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; public class YoloV5Example { public static void main(String[] args) throws IOException, TranslateException { Path imageFile = Paths.get("path/to/image.jpg"); Image image = ImageFactory.getInstance().fromFile(imageFile); // Load the YOLOv5 model ZooModel<Image, DetectedObjects> model = YoloV5.builder().build(); try (Model model = model.getModel()) { model.setBlockRunner(new CudaBlockRunner()); model.load(); // Create a translator to convert input/output Translator<Image, DetectedObjects> translator = new YoloV5Translator(); // Run inference on the image DetectedObjects detections = model.predict(translator.translate(image)); System.out.println(detections); } } private static final class YoloV5Translator implements Translator<Image, DetectedObjects> { @Override public Batchifier getBatchifier() { // YOLOv5 only supports batch size 1 return Batchifier.STACK; } @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws TranslateException { // Convert the output to DetectedObjects // ... } @Override public NDList processInput(TranslatorContext ctx, Image input) throws TranslateException { // Convert the input to NDList // ... } } public static final class Factory implements TranslatorFactory<Image, DetectedObjects> { @Override public Translator<Image, DetectedObjects> newInstance(Model model, Map<String, Object> arguments) { return new YoloV5Translator(); } } } ``` 这个代码使用DJL 框架来加载 YOLOv5 模型,并使用 CUDA 运行模型。它还定义了一个 `YoloV5Translator` 类来将输入和输出转换为适当的格式。最后,它使用 `model.predict()` 方法来运行推理,并将结果打印到控制台上。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

虹幺

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值