CenterNet model post-processing (C++ and Python code)

The network consists of three parts: a backbone (extracts high-level semantic features), upsampling (recovers resolution), and a head (three conv branches producing three output tensors): heatmap [B,C,H,W], wh [B,2,H,W], reg [B,2,H,W].
heatmap holds the object-center locations: wherever the map is closest to 1, that position is an object center.
wh has two channels holding the width and height of the box predicted at each point, so the top-left and bottom-right corners follow from x ∓ w/2 and y ∓ h/2.
reg has two channels holding the x/y sub-pixel offset of the center point; just add these values to the integer center coordinates.
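As a minimal sketch of this arithmetic (assuming an output stride of 4, matching the ×4 scaling in the C++ code below; the arrays here are illustrative dummies):

import numpy as np

stride = 4                                   # assumed output stride
H = W = 200                                  # assumed feature-map size
wh  = np.zeros((2, H, W), dtype=np.float32)  # dummy predicted box sizes
reg = np.zeros((2, H, W), dtype=np.float32)  # dummy predicted offsets

cy, cx = 80, 50                              # a peak found on the heatmap
w, h   = wh[0, cy, cx], wh[1, cy, cx]        # box width/height at the peak
dx, dy = reg[0, cy, cx], reg[1, cy, cx]      # sub-pixel center offset

x1 = (cx + dx - w / 2) * stride              # top-left corner
y1 = (cy + dy - h / 2) * stride
x2 = (cx + dx + w / 2) * stride              # bottom-right corner
y2 = (cy + dy + h / 2) * stride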
Why do wh and reg have only two channels?
    Because a center point can never fall outside the H×W map, even with multiple objects: imagine the per-class heatmaps collapsed into a single map, where every keypoint still lands inside that one H×W grid, so one shared wh/reg value per location is enough.
To decide which class a point belongs to: take the top 10 values from each heatmap. With, say, 2 channels (2 classes) you have 20 candidates in total; after a second top-10 over those, a candidate's index divided by 10 gives its class, as sketched below.
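A hedged torch sketch of that index trick, on a dummy 2-class heatmap (all shapes and names here are illustrative):

import torch

K = 10
heat = torch.rand(2, 200, 200)                        # dummy 2-class heatmap
# top-K per class over the flattened H*W map -> both results are (2, K)
topk_scores, topk_inds = torch.topk(heat.view(2, -1), K)
# second top-K over all 2*K candidates; they are grouped by class in
# blocks of K, so integer-dividing an index by K recovers the class id
topk_score, topk_ind = torch.topk(topk_scores.view(-1), K)
topk_clses = topk_ind // K                            # values in {0, 1}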
So the post-processing flow is:
    float hm_chn0[Height][Width] = {0};
    float hm_chn1[Height][Width] = {0};


    float reg_chn0[Height][Width] = {0};
    float reg_chn1[Height][Width] = {0};


    float wh_chn0[Height][Width] = {0};
    float wh_chn1[Height][Width] = {0};


        MI_FLOAT* phmdata = (MI_FLOAT*)(OutputTensorVector->astArrayTensors[0].ptTensorData[0]);
        MI_FLOAT* pwhdata = (MI_FLOAT*)(OutputTensorVector->astArrayTensors[1].ptTensorData[0]);
        MI_FLOAT* pregdata = (MI_FLOAT*)(OutputTensorVector->astArrayTensors[2].ptTensorData[0]);
        for(unsigned int h = 0; h < Height; h++)
        {
            for(unsigned int w = 0; w < Width; w++)
            {
                if(s32ClassCount % 8 == 0)
                {
                    // Read the inference output; the model output is
                    // channel-aligned (padded to 8), so the stride is 8.
                    hm_chn0[h][w]  = *(phmdata  + ((h*Width + w)*8));
                    hm_chn1[h][w]  = *(phmdata  + ((h*Width + w)*8 + 1));

                    wh_chn0[h][w]  = *(pwhdata  + ((h*Width + w)*8));
                    wh_chn1[h][w]  = *(pwhdata  + ((h*Width + w)*8 + 1));

                    reg_chn0[h][w] = *(pregdata + ((h*Width + w)*8));
                    reg_chn1[h][w] = *(pregdata + ((h*Width + w)*8 + 1));
                }
            }
        }
     


     cout<<"hm[0] " << hm_chn0[0][0]<<" " << hm_chn0[0][1]<<endl;
     //cout<<"hm[0] " << heatmap_chn0[0][0]<<" " << heatmap_chn0[0][1]<<endl;
    // ofs << std::endl << "}" << std::endl << std::endl;
    // ofs.close();


    
    Mat img_hm_0(Height, Width, CV_8UC1);
    Mat img_hm_1(Height, Width, CV_8UC1);
    Mat heatmap(Height, Width, CV_8UC1);
    Mat heatmap_1(Height, Width, CV_8UC1);
    Mat src(Height, Width, CV_32FC1);

    // Scale the channel-0 heatmap (values in 0..1) to 0..255 and save it.
    for (int row = 0; row < Height; row++)
    {
        for (int col = 0; col < Width; col++)
        {
            src.at<float>(row, col) = hm_chn0[row][col] * 255;
        }
    }
    src.convertTo(heatmap, CV_8U, 10, 0);
    imwrite("heatmap_0.jpg", heatmap);

    // Quantize both heatmap channels to 8-bit for the max-pool step below.
    src.convertTo(img_hm_0, CV_8U);
    for (int row = 0; row < Height; row++)
    {
        for (int col = 0; col < Width; col++)
        {
            src.at<float>(row, col) = hm_chn1[row][col] * 255;
        }
    }
    src.convertTo(img_hm_1, CV_8U);


    unsigned char c = img_hm_0.at<uchar>(0, 0);
    unsigned char b = img_hm_0.at<uchar>(0, 1);
    cout << "hm 0 uchar is " << (int)c << ' ' << (int)b << endl;
    
    // Max-pool the heatmap; dilation with a 3x3 kernel is used as a stand-in.
    Mat Hmax1(Height, Width, CV_8UC1);
    Mat Hmax2(Height, Width, CV_8UC1);

    Mat hmimg1 = img_hm_0.clone();
    Mat hmimg2 = img_hm_1.clone();

    cv::Mat element = getStructuringElement(MORPH_RECT, Size(3, 3));
    dilate(hmimg1, Hmax1, element);
    dilate(hmimg2, Hmax2, element);
    // Keep only local maxima: zero out any point that differs from the
    // 3x3 max around it (the same effect as max-pool-based NMS).
    for (int row = 0; row < Height; row++)
    {
        for (int col = 0; col < Width; col++)
        {
            if (hmimg1.at<uchar>(row, col) != Hmax1.at<uchar>(row, col))
            {
                hm_chn0[row][col] = 0;
            }
            if (hmimg2.at<uchar>(row, col) != Hmax2.at<uchar>(row, col))
            {
                hm_chn1[row][col] = 0;
            }
        }
    }
    // Save the suppressed channel-0 heatmap for inspection.
    for (int row = 0; row < Height; row++)
    {
        for (int col = 0; col < Width; col++)
        {
            src.at<float>(row, col) = hm_chn0[row][col] * 255;
        }
    }
    src.convertTo(heatmap_1, CV_8U, 10, 0);
    imwrite("heatmap_1.jpg", heatmap_1);
    
    cout<<"ssssssssss"<<endl;
    float topk_scores[Chn][topN];
    int topk_inds[Chn][topN];
    int topk_ys[Chn][topN];
    int topk_xs[Chn][topN];
    // Take the 10 largest values from each of the two heatmaps.
    // Pointer notes for the 2-D arrays: &hm_chn0[1] and hm_chn0 + 1 are the
    // address of row 1; hm_chn0[0] and *hm_chn0 are the address of the first
    // element; *(hm_chn0 + 1) is the address of row 1's first element;
    // *hm_chn0 + 1 is the address of the second element of row 0.
    topk(hm_chn0[0], onebuf, topN, topk_scores[0], topk_inds[0]);
    topk(hm_chn1[0], onebuf, topN, topk_scores[1], topk_inds[1]);
    float scores[Chn * topN];
    int num = 0;
    // Compute the coordinates of these 20 candidate points.
    for (unsigned int cl = 0; cl < Chn; cl++)
    {
        for (int n = 0; n < topN; n++)
        {
            // Flattened index within one H x W map (0 .. onebuf-1);
            // index = y * Width + x.
            topk_inds[cl][n] = topk_inds[cl][n] % onebuf;
            topk_ys[cl][n] = (int)(topk_inds[cl][n] / Width);
            topk_xs[cl][n] = (int)(topk_inds[cl][n] % Width);
            // Scores are in the range 0..1.
            scores[num] = topk_scores[cl][n];
            num++;
        }
    }


    float topk_score[topN]; // the 10 highest scores overall
    int topk_ind[topN];     // their positions (0..Chn*topN-1) in `scores`

    // Pick the top 10 out of the 20 candidates.
    topk(scores, Chn * topN, topN, topk_score, topk_ind);


    cout <<"Top scores is  " << topk_score[0]<<endl;


    int topk_clses[topN];
    int topk_y[topN];
    int topk_x[topN];
    int ind[topN];

    for (unsigned int cl = 0; cl < topN; cl++)
    {
        // Class of each of the 10 points: candidates are grouped by class in
        // blocks of topN, so index / topN is the class id (0..Chn-1).
        topk_clses[cl] = (int)(topk_ind[cl] / topN);
        // x, y center coordinates of the 10 points on the H x W map.
        topk_y[cl] = topk_ys[topk_clses[cl]][topk_ind[cl] % topN];
        topk_x[cl] = topk_xs[topk_clses[cl]][topk_ind[cl] % topN];
        // Flattened index in 0 .. H*W-1.
        ind[cl] = topk_inds[topk_clses[cl]][topk_ind[cl] % topN];
    }
    // The code above found the 10 highest peaks and their coordinates on the
    // [Height, Width, Chn] output.


    // Gather wh and reg to apply the regression values.
    float feat_reg[onebuf][Chn] = {0};
    float feat_wh[onebuf][Chn] = {0};

    float reg_view[topN][Chn];
    float wh_view[topN][Chn];

    float x[topN];
    float y[topN];
    float dets[topN][6];
    // Folding this loop into the max-pool loop above added ~2 ms, so it stays
    // separate. Flatten the reg/wh channels into (H*W, Chn) layout.
    int inc = 0;
    for (int row = 0; row < Height; row++)
    {
        for (int col = 0; col < Width; col++)
        {
            feat_reg[inc][0] = reg_chn0[row][col];
            feat_reg[inc][1] = reg_chn1[row][col];

            feat_wh[inc][0] = wh_chn0[row][col];
            feat_wh[inc][1] = wh_chn1[row][col];
            inc += 1;
        }
    }


    std::vector<int> ids;
    std::vector<cv::Rect> boxes;
    std::vector<float> confidences;
    for (int num = 0; num < topN; num++)
    {
        // reg holds the sub-pixel offset of the center point.
        reg_view[num][0] = feat_reg[ind[num]][0];
        reg_view[num][1] = feat_reg[ind[num]][1];
        // Add the regression offset to the integer coordinates.
        x[num] = topk_x[num] + reg_view[num][0];
        y[num] = topk_y[num] + reg_view[num][1];
        // Read the wh values at the top-10 indices.
        wh_view[num][0] = feat_wh[ind[num]][0];
        wh_view[num][1] = feat_wh[ind[num]][1];

        // Build the [10, 6] detection output: 4 (bbox) + 1 (score) + 1 (class).
        // The x4 maps feature-map coordinates back to network-input size.
        dets[num][0] = (x[num] - (wh_view[num][0] / 2)) * 4;
        if (dets[num][0] < 0)
        {
            dets[num][0] = 0;
        }

        dets[num][1] = (y[num] - (wh_view[num][1] / 2)) * 4;
        if (dets[num][1] < 0)
        {
            dets[num][1] = 0;
        }

        dets[num][2] = (x[num] + (wh_view[num][0] / 2)) * 4;
        if (dets[num][2] < 0)
        {
            dets[num][2] = 0;
        }

        dets[num][3] = (y[num] + (wh_view[num][1] / 2)) * 4;
        if (dets[num][3] < 0)
        {
            dets[num][3] = 0;
        }

        dets[num][4] = topk_score[num];
        dets[num][5] = topk_clses[num];

        ids.push_back(dets[num][5]);
        confidences.push_back(dets[num][4]);
        boxes.emplace_back((int)dets[num][0], (int)dets[num][1], (int)(wh_view[num][0] * 4), (int)(wh_view[num][1] * 4));
    }
    // NMS
    std::vector<int> indices;
    float score_threshold = 0.1;
    float nms_threshold = 0.1;

    NMSBoxes(boxes, confidences, score_threshold, nms_threshold, indices);


    // OpenCV loads images as BGR, HWC.
    cv::Mat img = cv::imread(image_path, -1);
    if (img.empty()) {
        std::cout << "error! image does not exist!" << std::endl;
        exit(1);
    }
    // Network input size.
    int net_w, net_h;
    net_w = Width * 4;
    net_h = Height * 4;
    std::vector<float> Result_str;


    
    cout<<"indec size is  "<<(int)(indices.size()) <<endl;
    for (size_t i = 0; i < indices.size(); ++i)
    {
        // These 4 corners are relative to the network input (net_w x net_h).
        int idx = indices[i];
        cv::Rect box = boxes[idx];
        float xmin = static_cast<float>(box.x);
        float ymin = static_cast<float>(box.y);
        float xmax = xmin + static_cast<float>(box.width);
        float ymax = ymin + static_cast<float>(box.height);

        cout << xmin << " " << ymin << " " << xmax << " " << ymax << " " << confidences[idx] << " " << ids[idx] << endl;

        Result_str.push_back(xmin);
        Result_str.push_back(ymin);
        Result_str.push_back(xmax);
        Result_str.push_back(ymax);
        Result_str.push_back(confidences[idx]);
        Result_str.push_back(ids[idx]);
        // Map the corners back to the original image (undo the letterbox).
        if (img.cols > img.rows) // width greater than height
        {
            xmin = xmin * img.cols / net_w;
            xmax = xmax * img.cols / net_w;
            ymin = (ymin * img.cols / net_w) - ((img.cols - img.rows) / 2);
            ymax = (ymax * img.cols / net_w) - ((img.cols - img.rows) / 2);

            if (ymin < 0)
            {
                ymin = 0;
            }
            if (ymax > img.rows)
            {
                ymax = img.rows;
            }
        }
        else // height greater than width
        {
            ymin = ymin * img.rows / net_h;
            ymax = ymax * img.rows / net_h;
            xmin = (xmin * img.rows / net_h) - ((img.rows - img.cols) / 2);
            xmax = (xmax * img.rows / net_h) - ((img.rows - img.cols) / 2);

            if (xmin < 0)
            {
                xmin = 0;
            }
            if (xmax > img.cols)
            {
                xmax = img.cols;
            }
        }
        cout << xmin << " " << ymin << " " << xmax << " " << ymax << " " << confidences[idx] << " " << ids[idx] << endl;

        cv::rectangle(img, Point((int)(xmin), (int)(ymin)), Point((int)(xmax), (int)(ymax)), Scalar(0, 0, 255), 2);
    }
    std::string save_out = "EEEEEE.jpg";
    imwrite(save_out, img);
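For reference, the same mapping back to the original image as a hedged numpy sketch (assuming the preprocessing center-padded the shorter side to a square before resizing to the network input; `unletterbox` is a hypothetical helper name, not part of the original code):

import numpy as np

def unletterbox(box, img_w, img_h, net_w, net_h):
    """Map an (x1, y1, x2, y2) box from network-input coordinates back to
    the original image, undoing centered square padding plus resize."""
    x1, y1, x2, y2 = box
    if img_w > img_h:                     # width was the long side
        s = img_w / net_w                 # resize factor back to original
        pad = (img_w - img_h) / 2         # vertical padding, original units
        x1, x2 = x1 * s, x2 * s
        y1, y2 = y1 * s - pad, y2 * s - pad
    else:                                 # height was the long side
        s = img_h / net_h
        pad = (img_h - img_w) / 2
        y1, y2 = y1 * s, y2 * s
        x1, x2 = x1 * s - pad, x2 * s - pad
    x1, x2 = np.clip([x1, x2], 0, img_w)  # clamp to image bounds
    y1, y2 = np.clip([y1, y2], 0, img_h)
    return x1, y1, x2, y2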

Python version

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from data import CenterAffine


def gather_feature(fmap, index, mask=None, use_transform=False):
    if use_transform:
        # change a (N, C, H, W) tensor to (N, HxW, C) shape
        batch, channel = fmap.shape[:2]
        fmap = fmap.view(batch, channel, -1).permute((0, 2, 1)).contiguous()


    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
    fmap = fmap.gather(dim=1, index=index)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(fmap)
        fmap = fmap[mask]
        fmap = fmap.reshape(-1, dim)
    return fmap


class CenterNetDecoder(object):
    @staticmethod
    def decode(fmap, wh, reg=None, cat_spec_wh=False, K=100):
        r"""
        decode output feature map to detection results

        Args:
            fmap(Tensor): output feature map
            wh(Tensor): tensor that represents predicted width-height
            reg(Tensor): tensor that represents regression of center points
            cat_spec_wh(bool): whether apply gather on tensor `wh` or not
            K(int): topk value
        """
        batch, channel, height, width = fmap.shape

        fmap = CenterNetDecoder.pseudo_nms(fmap)

        scores, index, clses, ys, xs = CenterNetDecoder.topk_score(fmap, K=K)
        if reg is not None:
            reg = gather_feature(reg, index, use_transform=True)
            reg = reg.reshape(batch, K, 2)
            xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
            ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
        else:
            xs = xs.view(batch, K, 1) + 0.5
            ys = ys.view(batch, K, 1) + 0.5
        wh = gather_feature(wh, index, use_transform=True)

        if cat_spec_wh:
            wh = wh.view(batch, K, channel, 2)
            clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
            wh = wh.gather(2, clses_ind).reshape(batch, K, 2)
        else:
            wh = wh.reshape(batch, K, 2)

        clses = clses.reshape(batch, K, 1).float()
        scores = scores.reshape(batch, K, 1)

        half_w, half_h = wh[..., 0:1] / 2, wh[..., 1:2] / 2
        bboxes = torch.cat([xs - half_w, ys - half_h, xs + half_w, ys + half_h], dim=2)

        detections = (bboxes, scores, clses)

        return detections

    @staticmethod
    def transform_boxes(boxes, img_info, scale=1):
        r"""
        transform predicted boxes to target boxes

        Args:
            boxes(Tensor): torch Tensor with (Batch, N, 4) shape
            img_info(dict): dict contains all information of original image
            scale(float): used for multiscale testing
        """
        boxes = boxes.cpu().numpy().reshape(-1, 4)

        center = img_info["center"]
        size = img_info["size"]
        output_size = (img_info["width"], img_info["height"])
        src, dst = CenterAffine.generate_src_and_dst(center, size, output_size)
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))

        coords = boxes.reshape(-1, 2)
        aug_coords = np.column_stack((coords, np.ones(coords.shape[0])))
        target_boxes = np.dot(aug_coords, trans.T).reshape(-1, 4)
        return target_boxes

    @staticmethod
    def pseudo_nms(fmap, pool_size=3):
        r"""
        apply max pooling to get the same effect of nms

        Args:
            fmap(Tensor): output tensor of previous step
            pool_size(int): size of max-pooling
        """
        pad = (pool_size - 1) // 2
        fmap_max = F.max_pool2d(fmap, pool_size, stride=1, padding=pad)
        keep = (fmap_max == fmap).float()
        return fmap * keep

    @staticmethod
    def topk_score(scores, K=40):
        """
        get top K point in score map
        """
        batch, channel, height, width = scores.shape

        # get topk score and its index in every H x W(channel dim) feature map
        topk_scores, topk_inds = torch.topk(scores.reshape(batch, channel, -1), K)

        topk_inds = topk_inds % (height * width)
        topk_ys = (topk_inds / width).int().float()
        topk_xs = (topk_inds % width).int().float()

        # get all topk in a batch
        topk_score, index = torch.topk(topk_scores.reshape(batch, -1), K)
        # div by K because index is grouped by K(C x K shape)
        topk_clses = (index / K).int()
        topk_inds = gather_feature(topk_inds.view(batch, -1, 1), index).reshape(batch, K)
        topk_ys = gather_feature(topk_ys.reshape(batch, -1, 1), index).reshape(batch, K)
        topk_xs = gather_feature(topk_xs.reshape(batch, -1, 1), index).reshape(batch, K)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
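
A minimal usage sketch with dummy tensors (shapes assumed; `fmap` is expected to be sigmoid-activated already, since `pseudo_nms` and `topk_score` read the scores as-is):

if __name__ == "__main__":
    batch, classes, h, w = 1, 2, 200, 200
    fmap = torch.sigmoid(torch.randn(batch, classes, h, w))  # dummy heatmap
    wh = torch.rand(batch, 2, h, w) * 10                     # dummy box sizes
    reg = torch.rand(batch, 2, h, w)                         # dummy offsets

    bboxes, scores, clses = CenterNetDecoder.decode(fmap, wh, reg, K=10)
    print(bboxes.shape, scores.shape, clses.shape)  # (1,10,4) (1,10,1) (1,10,1)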

