使用 java-onnx 部署 yolovx 实例分割

 处理 yolov5-seg 的实例分割,主要是计算mask信息,进行 output1和output2的矩阵相乘,然后对结果矩阵取sigmod以及二值化。可视化如下:

和目标检测一样,主要是模型输出有两个矩阵:

1 * 25200 * 117:这个就是25200个框的信息,每个框用117位的浮点数组。其中1是边框置信度、2~5是 xywh 也就是中心点坐标和宽高、5~85 就是80个类别的概率,85~117就是32位的 mask信息。

1 * 32 * 160 * 160:这就是在原始模型新加的proto小型CNN,和前面 32 mask信息做矩阵乘法就得到轮廓信息

和目标检测的区别在于,需要将数据1的后32位,乘数据2的32*160*160矩阵来得到目标的 mask 信息。参考 yolov5 源码下 general.py 中的 process_mask 函数,流程如下:

将 output1 的 32*160*160 也就是 float[][][] 转换成二维矩阵 32*25600 的 float[][],相当于把160*160展平,得到矩阵1,是 32行、25600列,

将 output2 的每个检测框的 117 取最后32位,转为一维矩阵 float[],得到矩阵2,是1行,32列。

然后将矩阵2和矩阵1乘法,1*32 乘 32*25600 ,得到矩阵3,是 1*25600。

将矩阵3再转换为二位矩阵 160*160,再进行插值扩展到原始图像大小得到矩阵4。

再对矩阵4每个元素进行sigmod以及二值化到0/2。根据目标检测框标注、就可以看到掩膜信息。

 后续再对截取的掩膜信息在原图上面标注轮廓信息即可。下面是参考 yolov5 源码下 general.py 中的 process_mask 函数的 java 版本,实现掩膜可视化:

// 参考 general.py 中的 process_mask 函数
public static void generateMaskInfo(List<Detection> detections,float[][][][] proto,int width,int height){
    // 32 * 160 * 160 这个是mask原型 c h w
    float[][][] maskSrc = proto[0];
    // 转为二维矩阵也就是 32 * 25600,也就是 32 行 25600 列,相当于把 160*160展平
    float[][] flattenedData = floatArray2floatArray(maskSrc);
    // 再转为矩阵
    RealMatrix m1 = MatrixUtils.createRealMatrix(floatArray2doubleArray(flattenedData));
    // 每个目标框
    detections.stream().forEach(detection -> {
        // 32 这个是mask 掩膜系数,也就是权重,转为矩阵
        float[] maskWeight = detection.getMaskWeight();
        // 作为一个行向量存储在m1中,也就是 1 行 32 列
        RealMatrix m2 = MatrixUtils.createRowRealMatrix(floatArray2doubleArray(maskWeight));
        // 矩阵乘法 1*32 乘 32*25600 得到 1*25600
        RealMatrix m3 = m2.multiply(m1);
        // 再将 1*25600 转回 160*160 也就是一个缩小的掩膜图
        RealMatrix m4 = transfer_25600_To_160_160(m3);
        // 对每个元素求sigmod限制到0~1,后续根据阈值进行二值化
        RealMatrix m5 = getSigmod(m4);
        // 将160*160上采样到图片原始尺寸
        RealMatrix m6 = resizeRealMatrix(m5,height,width);
        // 目标在原始图片上的xyxy
        showMatrixWithBox(m6,detection.getSrcXYXY()[0],detection.getSrcXYXY()[1],detection.getSrcXYXY()[2],detection.getSrcXYXY()[3]);
    });
}
public static float[][] floatArray2floatArray(float[][][] data){
    float[][] flattenedData = new float[data.length][data[0].length * data[0][0].length];
    for (int i = 0; i < data.length; i++) {
        float[][] slice = data[i];
        for (int j = 0; j < slice.length; j++) {
            System.arraycopy(slice[j], 0, flattenedData[i], j * slice[j].length, slice[j].length);
        }
    }
    return flattenedData;
}
public static double[] floatArray2doubleArray(float[] data){
    double[] maskDouble = new double[data.length];
    for (int j = 0; j < data.length; j++) {
        maskDouble[j] = (double) data[j];
    }
    return maskDouble;
}
public static double[][] floatArray2doubleArray(float[][] data){
    double[][] maskDouble = new double[data.length][data[0].length];
    for (int i = 0; i < data.length; i++) {
        for(int j=0; j<data[0].length;j++){
            maskDouble[i][j] = data[i][j];
        }
    }
    return maskDouble;
}
// 再将 1*25600 转回 160*160
public static RealMatrix transfer_25600_To_160_160(RealMatrix data){
    RealMatrix res = new Array2DRowRealMatrix(160, 160);
    for (int i = 0; i < 160; i++) {
        for (int j = 0; j < 160; j++) {
            int index = i * 160 + j;
            double value = data.getEntry(0, index);
            res.setEntry(i, j, value);
        }
    }
    return res;
}
    public static RealMatrix resizeRealMatrix(RealMatrix matrix, int newRows, int newCols) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        RealMatrix resizedMatrix = MatrixUtils.createRealMatrix(newRows, newCols);
        for (int i = 0; i < newRows; i++) {
            for (int j = 0; j < newCols; j++) {
                int origI = (int) Math.floor(i * rows / newRows);
                int origJ = (int) Math.floor(j * cols / newCols);
                double d = matrix.getEntry(origI, origJ);
//                if(d>=maskThreshold){
//                    d = 1;
//                }else{
//                    d = 0;
//                }
                resizedMatrix.setEntry(i, j, d);
            }
        }
        return resizedMatrix;
    }
// 弹窗显示一个 showMatrix 并画框
public static void showMatrixWithBox(RealMatrix matrix,float xmin,float ymin,float xmax,float ymax){
    // 转换 RealMatrix to BufferedImage
    int numRows = matrix.getRowDimension();
    int numCols = matrix.getColumnDimension();
    BufferedImage image = new BufferedImage(numCols, numRows, BufferedImage.TYPE_INT_RGB);
    for (int i = 0; i < numRows; i++) {
        for (int j = 0; j < numCols; j++) {
            double value = matrix.getEntry(i, j);
            int grayValue = (int) Math.round(value * 255.0);
            grayValue = Math.min(grayValue, 255);
            grayValue = Math.max(grayValue, 0);
            int pixelValue = (grayValue << 16) | (grayValue << 8) | grayValue;
            image.setRGB(j, i, pixelValue);
        }
    }
    // image 上画框
    Graphics2D graph = image.createGraphics();
    graph.setStroke(new BasicStroke(3));// 线粗细
    graph.setColor(Color.RED);
    // 画矩形
    graph.drawRect(
            Float.valueOf(xmin).intValue(),
            Float.valueOf(ymin).intValue(),
            Float.valueOf(xmax-xmin).intValue(),
            Float.valueOf(ymax-ymin).intValue());
    // 提交画框
    graph.dispose();
    // 弹窗显示
    JFrame frame = new JFrame("Image Dialog");
    frame.setSize(image.getWidth(), image.getHeight());
    JLabel label = new JLabel(new ImageIcon(image));
    frame.getContentPane().add(label);
    frame.setVisible(true);
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE)
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

0x13

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值