PyTorch Object Detection: Analysis of the IoU (Intersection over Union) Function


1. Source code

# Reference: https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L356
import torch

def compute_intersection(set_1, set_2):
    """
    Compute the pairwise intersection areas between anchors.
    Args:
        set_1: a tensor of dimensions (n1, 4), anchors given as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), anchors given as (xmin, ymin, xmax, ymax)
    Returns:
        intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)


def compute_jaccard(set_1, set_2):
    """
    Compute the pairwise Jaccard index (IoU) between anchors.
    Args:
        set_1: a tensor of dimensions (n1, 4), anchors given as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), anchors given as (xmin, ymin, xmax, ymax)
    Returns:
        Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # Find intersections
    intersection = compute_intersection(set_1, set_2)  # (n1, n2)

    # Find areas of each box in both sets
    areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)

    # Find the union
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)

    return intersection / union  # (n1, n2)
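
To make the (n1, n2) output concrete, here is a small worked example. It is a minimal sketch: the helper name pairwise_iou and the box values are made up for illustration, and the helper simply condenses the two functions above into one so the snippet runs on its own.

```python
import torch


def pairwise_iou(set_1, set_2):
    # Hypothetical helper: same computation as compute_jaccard above,
    # condensed into one function for a standalone demo.
    lower = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    dims = torch.clamp(upper - lower, min=0)  # negative extents mean "no overlap"
    inter = dims[:, :, 0] * dims[:, :, 1]  # (n1, n2)
    area_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    area_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
    union = area_1.unsqueeze(1) + area_2.unsqueeze(0) - inter  # (n1, n2)
    return inter / union


# One box in the first set, three in the second (illustrative values).
boxes_a = torch.tensor([[0., 0., 2., 2.]])
boxes_b = torch.tensor([[1., 1., 3., 3.],   # overlaps boxes_a in a 1 x 1 square
                        [0., 0., 2., 2.],   # identical to boxes_a
                        [5., 5., 6., 6.]])  # disjoint from boxes_a

iou = pairwise_iou(boxes_a, boxes_b)
print(iou.shape)  # torch.Size([1, 3])
print(iou)        # roughly [[0.1429, 1.0000, 0.0000]]
```

For the first pair the intersection is 1 and both areas are 4, so the union is 4 + 4 - 1 = 7 and the IoU is 1/7. The disjoint pair shows why the clamp matters: upper - lower is negative there, and without min=0 the product of two negative extents would give a spurious positive area.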

2. Analysis of compute_intersection(set_1, set_2)

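1. The arguments set_1 and set_2 are matrices of anchor box coordinates, with shapes n1 * 4 and n2 * 4 respectively.

2. torch.max(tensor_a, tensor_b): when torch.max is given two tensors, it compares them element by element and writes the larger value into the corresponding position of the result. The two inputs are broadcast against each other if their shapes differ.

3. set_1[:, :2].unsqueeze(1) takes the xmin, ymin of every anchor in set_1; the result has shape n1 * 1 * 2.

4. set_2[:, :2].unsqueeze(0) takes the xmin, ymin of every anchor in set_2; the result has shape 1 * n2 * 2.

5. set_1[:, 2:].unsqueeze(1) takes the xmax, ymax of every anchor in set_1; the result has shape n1 * 1 * 2.

6. set_2[:, 2:].unsqueeze(0) takes the xmax, ymax of every anchor in set_2; the result has shape 1 * n2 * 2.

7. When torch.max(a, b) is applied to the tensors above, they are first broadcast so that both a and b have shape n1 * n2 * 2, and the comparison is then done element by element.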

8. The original post illustrates the comparison in step 7 with a figure, which is not reproduced here.
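
The broadcasting described in steps 3 to 7 can also be checked numerically. The snippet below is a sketch with made-up box values (n1 = 2, n2 = 3): it prints the intermediate shapes and compares the broadcast torch.max against an explicit double loop over every (i, j) pair, which is what the broadcast computes implicitly.

```python
import torch

# Illustrative boxes in (xmin, ymin, xmax, ymax) form; the values are arbitrary.
set_1 = torch.tensor([[0., 0., 2., 2.],
                      [1., 1., 4., 4.]])   # n1 = 2
set_2 = torch.tensor([[1., 1., 3., 3.],
                      [0., 2., 2., 5.],
                      [3., 3., 5., 5.]])   # n2 = 3

a = set_1[:, :2].unsqueeze(1)   # (n1, 1, 2): xmin, ymin of each box in set_1
b = set_2[:, :2].unsqueeze(0)   # (1, n2, 2): xmin, ymin of each box in set_2
lower_bounds = torch.max(a, b)  # broadcast to (n1, n2, 2), then element-wise max

print(a.shape, b.shape, lower_bounds.shape)

# Explicit version: entry [i, j] compares box i of set_1 with box j of set_2.
expected = torch.empty(2, 3, 2)
for i in range(set_1.shape[0]):
    for j in range(set_2.shape[0]):
        expected[i, j] = torch.max(set_1[i, :2], set_2[j, :2])

print(torch.equal(lower_bounds, expected))  # True
print(lower_bounds[0, 1])                   # max of (0, 0) and (0, 2) is (0, 2)
```

The same pattern with torch.min on the [:, 2:] slices gives the upper bounds, so lower_bounds[i, j] and upper_bounds[i, j] together describe the top-left and bottom-right corners of the overlap between box i and box j.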

 

 

 

 
