SAHI 目标检测的 Java 实现

13 篇文章 0 订阅
11 篇文章 1 订阅

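SAHI(Slicing Aided Hyper Inference,切片辅助超推理)通过切片辅助推理与切片辅助微调,提高小目标的检测精度。本文只实现其推理流程(Android / Java):先把大图切成若干(可重叠的)小块分别送入检测器,再把各块的检测框映射回原图坐标,最后用非极大值抑制(NMS)去除重复框,并把结果绘制到图上。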

1. 图像切片与分块检测

private static void sahiImg(Bitmap tBitmap, int sWh, int boxWh){
        int dImgW = tBitmap.getWidth();
        int dImgH = tBitmap.getHeight();
        int cNum = (int) Math.ceil((float) dImgW / sWh);
        int rNum = (int) Math.ceil((float) dImgH / sWh);
        Log.d("testWH", dImgW + "," + dImgH+ "," +rNum+ "," +cNum+ "," +sWh);
        for(int i=0; i<rNum; i++) {
            for (int j = 0; j < cNum; j++) {
                int bX = sWh * j;
                int bY = sWh * i;
                if( bX >=dImgW-boxWh){
                    bX = dImgW-boxWh;//break;
                    j = cNum -1;
                }
                if( bY >=dImgH-boxWh){
                    bY = dImgH-boxWh;
                    i = rNum -1;
                }
                Bitmap part1bmap = Bitmap.createBitmap(tBitmap, bX, bY, boxWh, boxWh);
                Log.d("testBmap",  bX + "," + bY+ "," +boxWh+ "," +i+ ",a" +j);
                if(CONST.decGpu==1){ isGpu = true; }
                YoloV5Ncnn.dObj[] yoloObj = CONST.yolov5ncnn.Detect(part1bmap, isGpu);

                for (YoloV5Ncnn.dObj dObj : yoloObj) {
                    Float[] tBoxArr = new Float[4];
                    tBoxArr[0] = bX + dObj.x;
                    tBoxArr[1] = bY + dObj.y;
                    tBoxArr[2] = bX + dObj.x + dObj.w;
                    tBoxArr[3] = bY + dObj.y + dObj.h;
                    boxList.add(tBoxArr);
                    Log.d("testXY",   i+ "|" + j + "|a" + dObj.prob + CONST.yPestArr[dObj.label]);
                    pestIdxList.add(dObj.label);//
                    confidList.add(dObj.prob);
                }
            }
        }
     }

2. NMS(非极大值抑制)

    public static List<Integer> non_max_suppression(Float[][] box2Arr, List<Integer> pestIdxLs, List<Float> conFLs) {//single_class_
        if (box2Arr.length == 0)
            return null;
        List<Integer> confIdxLs = new ArrayList<>();//保存置信度大于CONF_THRESH的元素的下标
        List<Float> nConfLs = new ArrayList<>();//保存置信度大于CONF_THRESH的元素的值
        List<Integer> nPestIdxLs = new ArrayList<>();//pest name
        for (int i = 0; i < box2Arr.length; i++) {//confidences.size()
            float confVal = conFLs.get(i);
            Log.d("box",i +","+confVal);
            if (confVal > Float.parseFloat(CONST.conStr)){
                confIdxLs.add(i);
                nConfLs.add(confVal);
                nPestIdxLs.add(pestIdxLs.get(i));
            }
        }
        if (confIdxLs.isEmpty())
            return null;
        int aliveIdxSize = confIdxLs.size();

        List<Idxs> idxsList = new ArrayList<>();//将置信度与下标对应
        for (int i = 0; i < aliveIdxSize; i++) {
            //Idxs idxs = new Idxs(confIdxLs.get(i), nPestIdxLs.get(i), nConfLs.get(i));
            idxsList.add(new Idxs(confIdxLs.get(i), nPestIdxLs.get(i), nConfLs.get(i)));
        }
        Collections.sort(idxsList);//按score升序排列
        for (int i = 0; i < aliveIdxSize; i++) {
            Log.d("idxNum",i +","+idxsList.get(i).getIndex()+","+ idxsList.get(i).getPestIdx()+","+idxsList.get(i).getConfVal());
        }

        float ovXmin, ovYmin, ovXmax, ovYmax;
        float ovW, ovH, overArea, ovRatio;
        //取出得分最高的bbox,计算剩下的bbox与它的交并比iou,去掉大于iou_thresh的bbox
        List<Integer> pickList = new ArrayList<>();
        while (idxsList.size() > 0) {
            sleep(20);
            int lastN = idxsList.size() - 1;
            if(pickList.size() >= CONST.numDetect)//取置信度最高的NUM_DETECTIONS个结果
                break;
            int lastIdx = idxsList.get(lastN).getIndex();
            Log.d("idx",lastIdx+","+idxsList.get(lastN).getPestIdx()+","+idxsList.get(lastN).getConfVal()+"");
            float last_area = (box2Arr[lastIdx][2] -box2Arr[lastIdx][0]) * (box2Arr[lastIdx][3] -box2Arr[lastIdx][1]);//area=(xmax-xmin)*(ymax-ymin)
            pickList.add(lastIdx);
            List<Idxs> idxs_to_remove = new ArrayList<>();//交并比过大需要移除的bbox
            for (int i = 0; i < lastN; i++) {
                int iIdx = idxsList.get(i).getIndex();
                ovXmin = Math.max(box2Arr[lastIdx][0], box2Arr[iIdx][0]);
                ovYmin = Math.max(box2Arr[lastIdx][1], box2Arr[iIdx][1]);
                ovXmax = Math.min(box2Arr[lastIdx][2], box2Arr[iIdx][2]);
                ovYmax = Math.min(box2Arr[lastIdx][3], box2Arr[iIdx][3]);
                ovW = Math.max(0, ovXmax - ovXmin);
                ovH = Math.max(0, ovYmax - ovYmin);
                overArea = ovW * ovH;
                float i_area = (box2Arr[iIdx][2] -box2Arr[iIdx][0]) * (box2Arr[iIdx][3] -box2Arr[iIdx][1]);
                ovRatio = overArea / ( last_area + i_area - overArea);//IoU

                if (ovRatio > (float)CONST.iouThresh/100)
                    idxs_to_remove.add(idxsList.get(i));
            }
            idxs_to_remove.add(idxsList.get(lastN));
            Log.d("testIdx",idxs_to_remove.size()+"||"+idxsList.size());
            idxsList.removeAll(idxs_to_remove);
        }
        return pickList;
    }

3. 绘制检测结果

    private static void drawImg(Bitmap mBmap,Float[][] bboxes, List<Integer> pestIdxLs, List<Float> conFLs,List<Integer> pickIdxLs) {
        copyBMap = mBmap.copy(Bitmap.Config.ARGB_8888, true);
        final int[] colors = new int[] {
                Color.rgb( 54,  67, 244),
                Color.rgb( 99,  30, 233),
                Color.rgb(176,  39, 156),
                Color.rgb(183,  58, 103),
                Color.rgb(181,  81,  63),
                Color.rgb(243, 150,  33),
                Color.rgb(244, 169,   3),
                Color.rgb(212, 188,   0),
                Color.rgb(136, 150,   0),
                Color.rgb( 80, 175,  76),
                Color.rgb( 74, 195, 139),
                Color.rgb( 57, 220, 205),
                Color.rgb( 59, 235, 255),
                Color.rgb(  7, 193, 255),
                Color.rgb(  0, 152, 255),
                Color.rgb( 34,  87, 255),
                Color.rgb( 72,  85, 121),
                Color.rgb(158, 158, 158),
                Color.rgb(139, 125,  96)
        };
        Canvas canvas = new Canvas(copyBMap);
        Paint paint = new Paint();
        paint.setStyle(Paint.Style.STROKE);
        paint.setStrokeWidth(4);
        Paint textbgpaint = new Paint();
        textbgpaint.setColor(Color.WHITE);
        textbgpaint.setStyle(Paint.Style.FILL);
        Paint textpaint = new Paint();
        textpaint.setColor(Color.BLACK);
        textpaint.setTextSize(26);
        textpaint.setTextAlign(Paint.Align.LEFT);

        for (int i = 0; i < pickIdxLs.size(); i++) { //if(yObj[i].prob>= Float.parseFloat(CONST.conStr)) {//高于置信限值 CONST.conStr
                paint.setColor(colors[i % 19]);
                float rectX1 = bboxes[pickIdxLs.get(i)][0];
                float rectY1 = bboxes[pickIdxLs.get(i)][1];
                float rectX2 = bboxes[pickIdxLs.get(i)][2];
                float rectY2 = bboxes[pickIdxLs.get(i)][3];
                canvas.drawRect(rectX1, rectY1, rectX2, rectY2, paint);
                {// draw filled text inside image
                    String text = CONST.yPestArr[pestIdxLs.get(pickIdxLs.get(i))] + " = " + String.format("%.1f", conFLs.get(pickIdxLs.get(i)) * 100) + "%";
                    float text_width = textpaint.measureText(text);
                    float text_height = -textpaint.ascent() + textpaint.descent();
                    float lX1 = bboxes[pickIdxLs.get(i)][0];
                    float lY1 = bboxes[pickIdxLs.get(i)][1] - text_height;
                    if (lY1 < 0)
                        lY1 = 0;
                    if (lX1 + text_width > copyBMap.getWidth())
                        lX1 = copyBMap.getWidth() - text_width;

                    canvas.drawRect(lX1, lY1, lX1 + text_width, lY1 + text_height, textbgpaint);
                    canvas.drawText(text, lX1, lY1 - textpaint.ascent(), textpaint);

                    CONST.rectStr += (int) Math.round(rectX1) + "," + (int) Math.round(rectY1) + "," + (int) Math.round(rectX2) + "," + (int) Math.round(rectY2) + ";";
                    CONST.detectStr += CONST.yPestArr[pestIdxLs.get(pickIdxLs.get(i))] + ","
                            + String.format("%.1f", conFLs.get(pickIdxLs.get(i)) * 100) + ";";
                }   //}
        }
        canvas.save();
        canvas.restore();
    }

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值