Slicing Aided Hyper Inference: a brief look at the get_sliced_prediction source in the sahi library

Code: https://github.com/obss/sahi

The key to reading the get_sliced_prediction source is understanding NMS and NMM. NMS comes up all the time, so I won't cover it here.

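Of the two, NMM (Non-Max Merging) is the most important part, and it is actually quite similar to NMS. I came across a blog post that explains how it works very well: https://blog.roboflow.com/non-max-merging/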

The most important part is excerpted below:

Here are the steps Non-Max Merge takes:

  1. First, it sorts all detections by their confidence score, from highest to lowest.
  2. It then takes all pairs of detections and computes their IOU, checking how much the pair overlaps.
  3. From most confident to least, it will build groups of overlapping detections.
    1. It starts by creating a new group with the most confident non-grouped detection D1.
    2. Then, each non-grouped detection that overlaps with D1 by at least iou_threshold (specified by the user) is placed in the same group.
    3. By repeating these two steps, we end up with mutually exclusive groups, such as [[D1, D2, D4], [D3], [D5, D6]].
  4. Then merging begins. This is done with detection pairs (D1, D2) and is implementation-specific. In supervision we:
    1. Make a new bounding box xyxy to fit both D1 and D2.
    2. Make a new mask containing pixels where the masks of D1 or D2 were.
    3. Create a new confidence value, adding together the confidence of D1 and D2, normalized by their xyxy areas.
      New Conf = (Conf 1 * Area 1 + Conf 2 * Area 2) / (Area 1 + Area 2)
    4. Copy class_id, tracker_id, and data from the Detection with the higher confidence.
  5. The prior step is done on detection pairs. How do we merge the whole group?
    1. Create an empty list for results.
    2. If there's only one detection in a group, add it to the results list.
    3. Otherwise, pick the first two detections, compute the IOU again, and if it's above the user-specified iou_threshold, pairwise merge it as outlined in the prior step.
      The resulting merged detection stays in the group as the new first element, and the group is shortened by 1. Continue pairwise merging while there are at least two elements in a group.
      Note that the IOU calculation makes the algorithm more costly but is required to prevent the merged detection from growing boundlessly.
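
The steps in the excerpt can be sketched in plain Python. This is a minimal sketch, not supervision's code: `group_detections`, `merge_pair` and `non_max_merge` are hypothetical names, boxes are `(x1, y1, x2, y2)` tuples, masks and class/tracker ids are left out, and when the re-checked IoU inside a group falls below the threshold this sketch simply skips that member (the excerpt does not say what happens in that case).

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def area(b: Box) -> float:
    return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])


def iou(a: Box, b: Box) -> float:
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0


def group_detections(boxes: List[Box], scores: List[float], iou_threshold: float) -> List[List[int]]:
    """Steps 1-3: mutually exclusive groups, each seeded by the most confident ungrouped box."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    grouped, groups = set(), []
    for i in order:
        if i in grouped:
            continue
        group = [i]
        grouped.add(i)
        for j in order:  # still in descending-confidence order
            if j not in grouped and iou(boxes[i], boxes[j]) >= iou_threshold:
                group.append(j)
                grouped.add(j)
        groups.append(group)
    return groups


def merge_pair(box1: Box, conf1: float, box2: Box, conf2: float) -> Tuple[Box, float]:
    """Step 4: enclosing box plus area-weighted confidence."""
    box = (min(box1[0], box2[0]), min(box1[1], box2[1]),
           max(box1[2], box2[2]), max(box1[3], box2[3]))
    a1, a2 = area(box1), area(box2)
    conf = (conf1 * a1 + conf2 * a2) / (a1 + a2) if a1 + a2 > 0 else max(conf1, conf2)
    return box, conf


def non_max_merge(boxes: List[Box], scores: List[float], iou_threshold: float = 0.5) -> List[Tuple[Box, float]]:
    """Step 5: merge each group front to back, re-checking IoU before every pairwise merge."""
    results = []
    for group in group_detections(boxes, scores, iou_threshold):
        box, conf = boxes[group[0]], scores[group[0]]
        for j in group[1:]:
            # the IoU re-check keeps the merged box from growing without bound;
            # members that fail it are skipped (an assumption of this sketch)
            if iou(box, boxes[j]) >= iou_threshold:
                box, conf = merge_pair(box, conf, boxes[j], scores[j])
        results.append((box, conf))
    return results


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.6, 0.8]
print(non_max_merge(boxes, scores, 0.5))
```

The two overlapping boxes merge into `(0, 0, 11, 11)` with confidence (0.9 * 100 + 0.6 * 100) / 200 = 0.75, while the isolated box is kept unchanged.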

The get_sliced_prediction function in sahi is as follows:

def get_sliced_prediction(
    image,
    detection_model=None,
    slice_height: int = None,
    slice_width: int = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: int = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: str = None,
    slice_dir: str = None,
) -> PredictionResult:
    """
    Function for slice image + get prediction for each slice + combine predictions in full image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        slice_height: int
            Height of each slice.  Defaults to ``None``.
        slice_width: int
            Width of each slice.  Defaults to ``None``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        perform_standard_pred: bool
            Perform a standard prediction on top of sliced predictions to increase large object
            detection accuracy. Default: True.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print number of slices (default)
            2: print number of slices and slice/prediction durations
        merge_buffer_length: int
            The length of buffer for slices to be used during sliced prediction, which is suitable for low memory.
            It may affect the AP if it is specified. The higher the amount, the closer the results are to the non-buffered
            scenario. See [the discussion](https://github.com/obss/sahi/pull/445).
        auto_slice_resolution: bool
            if slice parameters (slice_height, slice_width) are not given,
            it enables automatically calculate these params from image resolution and orientation.
        slice_export_prefix: str
            Prefix for the exported slices. Defaults to None.
        slice_dir: str
            Directory to save the slices. Defaults to None.

    Returns:
        A Dict with fields:
            object_prediction_list: a list of sahi.prediction.ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """

    # for profiling
    durations_in_seconds = dict()

    # currently only 1 batch supported
    num_batch = 1
    # create slices from full image
    time_start = time.time()
    # slice the image
    slice_image_result = slice_image(
        image=image,
        output_file_name=slice_export_prefix,
        output_dir=slice_dir,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        auto_slice_resolution=auto_slice_resolution,
    )

    num_slices = len(slice_image_result)
    time_end = time.time() - time_start
    durations_in_seconds["slice"] = time_end

    # init match postprocess instance
    # supported postprocess types: "GREEDYNMM", "NMM", "NMS", "LSNMS"; GREEDYNMM is the default
    if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
        raise ValueError(
            f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
        )
    elif postprocess_type == "UNIONMERGE":
        # deprecated in v0.9.3
        raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
    # pick the postprocess class
    postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
    postprocess = postprocess_constructor(
        match_threshold=postprocess_match_threshold,
        match_metric=postprocess_match_metric,
        class_agnostic=postprocess_class_agnostic,
    )

    # create prediction input
    num_group = int(num_slices / num_batch)
    if verbose == 1 or verbose == 2:
        tqdm.write(f"Performing prediction on {num_slices} slices.")
    object_prediction_list = []
    # perform sliced prediction
    # run standard inference on each slice; the only extra work is mapping slice coordinates back to the original full image
    for group_ind in range(num_group):
        # prepare batch (currently supports only 1 batch)
        image_list = []
        shift_amount_list = []
        for image_ind in range(num_batch):
            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
            shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
        # perform batch prediction
        # num_batch = 1, so image_list always holds exactly one slice
        prediction_result = get_prediction(
            image=image_list[0],
            detection_model=detection_model,
            shift_amount=shift_amount_list[0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
        )
        # convert sliced predictions to full predictions
        for object_prediction in prediction_result.object_prediction_list:
            if object_prediction:  # if not empty
                object_prediction_list.append(object_prediction.get_shifted_object_prediction())

        # merge matching predictions during sliced prediction
        if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
            object_prediction_list = postprocess(object_prediction_list)

    # perform standard prediction
    if num_slices > 1 and perform_standard_pred:
        prediction_result = get_prediction(
            image=image,
            detection_model=detection_model,
            shift_amount=[0, 0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
            postprocess=None,
        )
        object_prediction_list.extend(prediction_result.object_prediction_list)

    # merge matching predictions
    # postprocess the results: merge
    if len(object_prediction_list) > 1:
        object_prediction_list = postprocess(object_prediction_list)

    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end

    if verbose == 2:
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
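
In the code above, slice_image produces the slice windows together with their starting_pixels, and each slice-level prediction is shifted back by that offset (get_shifted_object_prediction) before merging. The sketch below reproduces that idea in plain Python under my own assumptions: `get_slice_windows` and `shift_box` are hypothetical names, and the edge handling (pushing the last window back inside the image so every slice keeps its full size) reflects how I understand sahi's slicing to behave, not a copy of its code.

```python
from typing import List, Tuple

Window = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max)


def get_slice_windows(image_height: int, image_width: int, slice_height: int, slice_width: int,
                      overlap_height_ratio: float = 0.2, overlap_width_ratio: float = 0.2) -> List[Window]:
    """Overlapping windows covering the image; windows at the border are shifted back inside it."""
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError("slice size must be positive")
    if not (0 <= overlap_height_ratio < 1 and 0 <= overlap_width_ratio < 1):
        raise ValueError("overlap ratios must be in [0, 1)")
    y_overlap = int(overlap_height_ratio * slice_height)  # 0.2 * 512 -> 102 pixels
    x_overlap = int(overlap_width_ratio * slice_width)
    windows: List[Window] = []
    y_min = y_max = 0
    while y_max < image_height:
        y_max = y_min + slice_height
        x_min = x_max = 0
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # clamp to the border, then move the window back so it keeps the full slice size
                xmax, ymax = min(image_width, x_max), min(image_height, y_max)
                windows.append((max(0, xmax - slice_width), max(0, ymax - slice_height), xmax, ymax))
            else:
                windows.append((x_min, y_min, x_max, y_max))
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return windows


def shift_box(box, shift):
    """Slice coordinates -> full-image coordinates; shift is the window's (x_min, y_min)."""
    x1, y1, x2, y2 = box
    dx, dy = shift
    return (x1 + dx, y1 + dy, x2 + dx, y2 + dy)


windows = get_slice_windows(1000, 1000, 512, 512)
print(len(windows), windows[:3])
# a box found at (10, 20, 30, 40) inside the second slice, mapped back to the full image
print(shift_box((10, 20, 30, 40), windows[1][:2]))
```

A 1000 x 1000 image with 512 x 512 slices and 0.2 overlap gives three windows per axis, nine in total; the last window in each row and column is pulled back to start at 488 so it still spans 512 pixels.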

1. By default, the postprocess in get_sliced_prediction is GreedyNMMPostprocess from combine.py, which is the core of the merging step:

class GreedyNMMPostprocess(PostprocessPredictions):
    def __call__(
            self,
            object_predictions: List[ObjectPrediction],
    ):
        object_prediction_list = ObjectPredictionList(object_predictions)
        # convert to a PyTorch tensor
        object_predictions_as_torch = object_prediction_list.totensor()
        if self.class_agnostic:  # usually not taken (class_agnostic defaults to False)
            keep_to_merge_list = greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )
        else:  # usually taken: find the predictions that need to be merged, per category
            keep_to_merge_list = batched_greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )

        selected_object_predictions = []
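        # a snippet of keep_to_merge_list captured at runtime: {34: [45, 53], 5: [29], 6: []}
        # 34: [45, 53] means prediction 34 is merged with predictions 45 and 53; 5: [29] means prediction 5 is merged with 29;
        # 6: [] means prediction 6 is not merged with any box.
        # Merging is incremental: the result of one merge is the input to the next.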
        for keep_ind, merge_ind_list in keep_to_merge_list.items():
            for merge_ind in merge_ind_list:
                # IoU or IoS exceeds the given threshold
                if has_match(
                        object_prediction_list[keep_ind].tolist(),
                        object_prediction_list[merge_ind].tolist(),
                        self.match_metric,
                        self.match_threshold,
                ):
                    # main merge function; merges the box coordinates, score and category
                    object_prediction_list[keep_ind] = merge_object_prediction_pair(
                        object_prediction_list[keep_ind].tolist(), object_prediction_list[merge_ind].tolist()
                    )
            selected_object_predictions.append(object_prediction_list[keep_ind].tolist())

        return selected_object_predictions

Note that sahi merges scores by simply taking the maximum of the two, not with the area-weighted formula from the blog post.
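
The incremental merge loop of GreedyNMMPostprocess and the max-score rule can be sketched in plain Python to compare the two scoring rules. This is a simplified sketch, not sahi's code: predictions are `(box, score)` pairs, categories are ignored, the match check uses IoS with a strict `>` (assumed to mirror has_match), and the merged box is taken to be the enclosing box of the pair (also an assumption). `apply_keep_to_merge`, `merge_pair_sahi_style` and `area_weighted_score` are hypothetical names.

```python
from typing import Dict, List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Pred = Tuple[Box, float]  # (box, score)


def area(b: Box) -> float:
    return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])


def ios(a: Box, b: Box) -> float:
    """Intersection over the smaller of the two areas."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    smaller = min(area(a), area(b))
    return w * h / smaller if smaller > 0 else 0.0


def merge_pair_sahi_style(p1: Pred, p2: Pred) -> Pred:
    """Enclosing box, and the maximum of the two scores."""
    (b1, s1), (b2, s2) = p1, p2
    box = (min(b1[0], b2[0]), min(b1[1], b2[1]), max(b1[2], b2[2]), max(b1[3], b2[3]))
    return box, max(s1, s2)


def area_weighted_score(p1: Pred, p2: Pred) -> float:
    """The blog's rule: (Conf1 * Area1 + Conf2 * Area2) / (Area1 + Area2)."""
    (b1, s1), (b2, s2) = p1, p2
    a1, a2 = area(b1), area(b2)
    return (s1 * a1 + s2 * a2) / (a1 + a2) if a1 + a2 > 0 else max(s1, s2)


def apply_keep_to_merge(preds: List[Pred], keep_to_merge: Dict[int, List[int]],
                        threshold: float = 0.5) -> List[Pred]:
    """Merge incrementally: each result is the kept prediction for the next match check."""
    preds = list(preds)  # work on a copy
    selected = []
    for keep_ind, merge_inds in keep_to_merge.items():
        for merge_ind in merge_inds:
            # re-check against the kept box, which may already have grown
            if ios(preds[keep_ind][0], preds[merge_ind][0]) > threshold:
                preds[keep_ind] = merge_pair_sahi_style(preds[keep_ind], preds[merge_ind])
        selected.append(preds[keep_ind])
    return selected


preds = [((0, 0, 10, 10), 0.9), ((2, 0, 12, 10), 0.5), ((50, 50, 60, 60), 0.7)]
print(apply_keep_to_merge(preds, {0: [1], 2: []}))
print(area_weighted_score(preds[0], preds[1]))
```

For the first pair the max rule keeps 0.9, whereas the blog's formula would give (0.9 * 100 + 0.5 * 100) / 200 = 0.7; with the max rule a merged score can never fall below the best member's score.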

2. The batched_greedy_nmm function

def batched_greedy_nmm(
        object_predictions_as_tensor: torch.tensor,
        match_metric: str = "IOU",
        match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging per category to avoid detecting
    too many overlapping bounding boxes for a given object.
    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes, 6].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for
            match metric.
    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # x1,y1,x2,y2,score,clsid
    category_ids = object_predictions_as_tensor[:, 5].squeeze()
    keep_to_merge_list = {}
    for category_id in torch.unique(category_ids):
        curr_indices = torch.where(category_ids == category_id)[0]  # global indices
        # merge candidates among the boxes of the current category, as local indices
        curr_keep_to_merge_list = greedy_nmm(object_predictions_as_tensor[curr_indices], match_metric, match_threshold)
        curr_indices_list = curr_indices.tolist()
        # map local indices back to global indices
        for curr_keep, curr_merge_list in curr_keep_to_merge_list.items():
            keep = curr_indices_list[curr_keep]
            merge_list = [curr_indices_list[curr_merge_ind] for curr_merge_ind in curr_merge_list]
            keep_to_merge_list[keep] = merge_list
    return keep_to_merge_list
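
The per-category loop in batched_greedy_nmm boils down to two index operations: collect the global indices of each category, then translate the local keep/merge indices that greedy_nmm returns for that subset back to global ones. A small pure-Python sketch; `indices_by_category` and `local_to_global` are hypothetical helpers, and the local result in the example is made up rather than produced by greedy_nmm.

```python
from typing import Dict, List


def indices_by_category(category_ids: List[int]) -> Dict[int, List[int]]:
    """Global indices of the predictions in each category (what torch.where(category_ids == c)[0] yields)."""
    groups: Dict[int, List[int]] = {}
    for global_ind, category_id in enumerate(category_ids):
        groups.setdefault(category_id, []).append(global_ind)
    return groups


def local_to_global(local_keep_to_merge: Dict[int, List[int]], curr_indices: List[int]) -> Dict[int, List[int]]:
    """Translate a keep_to_merge_list computed on one category's subset back to full-list indices."""
    return {
        curr_indices[keep]: [curr_indices[m] for m in merge_list]
        for keep, merge_list in local_keep_to_merge.items()
    }


category_ids = [0, 1, 0, 1, 0]
groups = indices_by_category(category_ids)
print(groups)  # {0: [0, 2, 4], 1: [1, 3]}
# a made-up local result for category 0: local box 2 absorbs local box 0, local box 1 stays alone
print(local_to_global({2: [0], 1: []}, groups[0]))  # {4: [0], 2: []}
```

Running greedy_nmm per category is what keeps boxes of different classes from ever being merged when class_agnostic is False.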

3. The greedy_nmm function

def greedy_nmm(
        object_predictions_as_tensor: torch.tensor,
        match_metric: str = "IOU",
        match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes, 6].
        object_predictions_as_list: ObjectPredictionList Object prediction objects
            to be merged.
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for
            match metric.
    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # non-maximum merging, similar to NMS
    keep_to_merge_list = {}

    # we extract coordinates for every
    # prediction box present in P
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]

    # we extract the confidence scores as well
    scores = object_predictions_as_tensor[:, 4]

    # calculate area of every block in P
    areas = (x2 - x1) * (y2 - y1)

    # sort the prediction boxes in P
    # according to their confidence scores
    order = scores.argsort()

    while len(order) > 0:
        # extract the index of the
        # prediction with highest score
        # we call this prediction S
        idx = order[-1]

        # remove S from P
        order = order[:-1]

        # sanity check
        if len(order) == 0:
            keep_to_merge_list[idx.tolist()] = []
            break

        # select coordinates of BBoxes according to
        # the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # take max with 0.0 to avoid negative w and h
        # due to non-overlapping boxes
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # find the intersection area
        inter = w * h

        # find the areas of BBoxes according the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        if match_metric == "IOU":
            # find the union of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            union = (rem_areas - inter) + areas[idx]
            # find the IoU of every prediction in P with S
            match_metric_value = inter / union

        elif match_metric == "IOS":
            # find the smaller area of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            smaller = torch.min(rem_areas, areas[idx])
            # find the IoS of every prediction in P with S
            match_metric_value = inter / smaller
        else:
            raise ValueError()

        # keep the boxes with IoU/IoS less than thresh_iou
        mask = match_metric_value < match_threshold

        # matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,))

        ids = (mask == False).nonzero().flatten()
        matched_box_indices0 = order[ids]
        matched_box_indices = matched_box_indices0.flip(dims=(0,))  # flip so the matches are in descending score order

        unmatched_indices = order[(mask == True).nonzero().flatten()]

        # update box pool
        order = unmatched_indices[scores[unmatched_indices].argsort()]

        # create keep_ind to merge_ind_list mapping
        keep_to_merge_list[idx.tolist()] = []

        for matched_box_ind in matched_box_indices.tolist():
            keep_to_merge_list[idx.tolist()].append(matched_box_ind)

    return keep_to_merge_list
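
To check the index bookkeeping of greedy_nmm without PyTorch, here is a pure-Python mirror of the same loop, written as a sketch: `greedy_nmm_py` is a hypothetical name, rows are assumed to be `[x1, y1, x2, y2, score, ...]`, and ties in score may be ordered differently than with `torch.argsort`. Like the original, a box counts as matched when its IoU/IoS with the current top box is greater than or equal to `match_threshold`.

```python
from typing import Dict, List, Sequence


def greedy_nmm_py(preds: Sequence[Sequence[float]], match_metric: str = "IOU",
                  match_threshold: float = 0.5) -> Dict[int, List[int]]:
    areas = [(p[2] - p[0]) * (p[3] - p[1]) for p in preds]
    # ascending by score, so the best remaining prediction is always at the end
    order = sorted(range(len(preds)), key=lambda i: preds[i][4])
    keep_to_merge: Dict[int, List[int]] = {}
    while order:
        idx = order.pop()  # highest-scoring remaining prediction S
        matched, unmatched = [], []
        for j in order:
            w = max(0.0, min(preds[j][2], preds[idx][2]) - max(preds[j][0], preds[idx][0]))
            h = max(0.0, min(preds[j][3], preds[idx][3]) - max(preds[j][1], preds[idx][1]))
            inter = w * h
            if match_metric == "IOU":
                denom = areas[j] + areas[idx] - inter  # union
            elif match_metric == "IOS":
                denom = min(areas[j], areas[idx])  # smaller area
            else:
                raise ValueError(f"unknown match_metric: {match_metric}")
            value = inter / denom if denom > 0 else 0.0
            (matched if value >= match_threshold else unmatched).append(j)
        # order is ascending, so reverse to list merge candidates by descending score
        keep_to_merge[idx] = matched[::-1]
        order = unmatched  # filtering keeps it sorted ascending
    return keep_to_merge


preds = [
    [0, 0, 10, 10, 0.9, 0],
    [1, 1, 11, 11, 0.8, 0],
    [50, 50, 60, 60, 0.7, 0],
    [2, 2, 9, 9, 0.6, 0],  # lies entirely inside box 0
]
print(greedy_nmm_py(preds, "IOU", 0.5))
print(greedy_nmm_py(preds, "IOS", 0.5))
```

Box 3 sits entirely inside box 0, so its IoU is only 49 / 100 = 0.49 while its IoS is 1.0: with IOU it survives as a separate keep, with IOS it is merged into box 0. That behavior presumably suits the partial boxes produced at slice borders, which is consistent with IOS being the default postprocess_match_metric in get_sliced_prediction.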

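SAHI is a slicing-aided inference framework meant to help developers tackle real-world object detection problems. It improves detection performance by splitting the image into multiple slices, which helps with real-world difficulties such as varying object sizes, occlusion and varying object density. The core idea of SAHI is to cut the image into slices, run detection on each slice separately, and then merge the per-slice results into the final detections. This improves detection performance, especially for small objects.

Below is a simple Python example of the slicing idea. Note that it does not actually use the sahi library: the tiles do not overlap and detections are not merged across tile borders, so an object cut by a border may be reported twice or missed. The model files and the class list are placeholders.

```python
import cv2
import numpy as np

# load the image
img = cv2.imread('test.jpg')
if img is None:
    raise FileNotFoundError('test.jpg')

# slice size
slice_size = 512

# image size
height, width, _ = img.shape

# number of slices along each axis (no overlap between slices here)
num_slices_h = int(np.ceil(height / slice_size))
num_slices_w = int(np.ceil(width / slice_size))

# detector (a Caffe SSD-style model; the file names are placeholders)
detector = cv2.dnn.readNetFromCaffe('deploy.prototxt', 'model.caffemodel')

# class labels (must match the model's label map; placeholder list)
class_labels = ['person', 'car', 'truck', 'bus']

# detection results as (x, y, w, h) in full-image coordinates
results = []

# loop over every slice
for i in range(num_slices_h):
    for j in range(num_slices_w):
        # slice coordinates in the full image
        x1 = j * slice_size
        y1 = i * slice_size
        x2 = min(x1 + slice_size, width)
        y2 = min(y1 + slice_size, height)
        slice_w, slice_h = x2 - x1, y2 - y1

        # extract the slice
        slice_img = img[y1:y2, x1:x2]

        # build the input blob
        blob = cv2.dnn.blobFromImage(slice_img, 1.0, (300, 300), (104.0, 177.0, 123.0))

        # run detection
        detector.setInput(blob)
        detections = detector.forward()

        # parse the detections: each row is [image_id, class_id, confidence, x1, y1, x2, y2]
        for k in range(detections.shape[2]):
            confidence = detections[0, 0, k, 2]
            class_id = int(detections[0, 0, k, 1])
            # keep confident 'person' detections; guard against ids outside class_labels
            if confidence > 0.5 and 0 <= class_id < len(class_labels) and class_labels[class_id] == 'person':
                # coordinates are normalized to the slice: scale by the slice size, then shift by the slice offset
                bx1 = int(detections[0, 0, k, 3] * slice_w) + x1
                by1 = int(detections[0, 0, k, 4] * slice_h) + y1
                bx2 = int(detections[0, 0, k, 5] * slice_w) + x1
                by2 = int(detections[0, 0, k, 6] * slice_h) + y1
                results.append((bx1, by1, bx2 - bx1, by2 - by1))

# draw the detections on the original image
for (x, y, w, h) in results:
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)

# show the result
cv2.imshow('result', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
```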