1. Source code
```python
# Reference: https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L356
import torch


def compute_intersection(set_1, set_2):
    """
    Compute the pairwise intersection areas between two sets of anchor boxes.
    Args:
        set_1: a tensor of dimensions (n1, 4), each anchor given as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), each anchor given as (xmin, ymin, xmax, ymax)
    Returns:
        intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # PyTorch auto-broadcasts singleton dimensions
    # Top-left corner of each overlap region: element-wise max of the two (xmin, ymin) pairs
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    # Bottom-right corner of each overlap region: element-wise min of the two (xmax, ymax) pairs
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    # Width and height of the overlap; negative values mean no overlap, so clamp them to 0
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)


def compute_jaccard(set_1, set_2):
    """
    Compute the pairwise Jaccard coefficient (IoU) between two sets of anchor boxes.
    Args:
        set_1: a tensor of dimensions (n1, 4), each anchor given as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), each anchor given as (xmin, ymin, xmax, ymax)
    Returns:
        Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # Find intersections
    intersection = compute_intersection(set_1, set_2)  # (n1, n2)

    # Find areas of each box in both sets
    areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)

    # Find the union: |A| + |B| - |A ∩ B|
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)

    return intersection / union  # (n1, n2)
```
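You can sanity-check what the broadcast version computes against a minimal loop-based sketch that builds the same (n1, n2) IoU matrix one box pair at a time. The function name `jaccard_naive` and the sample boxes are hypothetical and exist only for illustration. The sketch gives up speed for readability.

```python
import torch


def jaccard_naive(set_1: torch.Tensor, set_2: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based IoU for checking the broadcast version.

    Boxes are (xmin, ymin, xmax, ymax); returns an (n1, n2) tensor.
    """
    n1, n2 = set_1.shape[0], set_2.shape[0]
    iou = torch.zeros(n1, n2)
    for i in range(n1):
        for j in range(n2):
            ax1, ay1, ax2, ay2 = set_1[i].tolist()
            bx1, by1, bx2, by2 = set_2[j].tolist()
            # Overlap width/height, clamped to 0 when the boxes do not intersect
            w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
            h = max(0.0, min(ay2, by2) - max(ay1, by1))
            inter = w * h
            # Union = area of A + area of B - intersection
            union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
            iou[i, j] = inter / union
    return iou


# Hypothetical sample boxes
anchors = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                        [1.0, 1.0, 3.0, 3.0]])
gt = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                   [2.0, 2.0, 4.0, 4.0],
                   [5.0, 5.0, 6.0, 6.0]])
print(jaccard_naive(anchors, gt))
```

For these inputs the expected matrix is [[1, 0, 0], [1/7, 1/7, 0]]. The second anchor overlaps each of the first two ground-truth boxes in a 1 × 1 square, so IoU = 1 / (4 + 4 - 1). Running `compute_jaccard(anchors, gt)` from the code above on the same inputs should give the same matrix. The vectorized version replaces the double loop with broadcasting.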
2. Analysis of the compute_intersection(set_1, set_2) function
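1. The inputs set_1 and set_2 are matrices of anchor-box coordinates, with shapes n1 × 4 and n2 × 4 respectively.
2. When torch.max(tensor_a, tensor_b) receives two tensors, it compares them element by element and writes the larger value to the corresponding position of the result. The two inputs may also be broadcast against each other.
3. set_1[:, :2].unsqueeze(1) extracts the xmin, ymin of every anchor in set_1; the result has shape n1 × 1 × 2.
4. set_2[:, :2].unsqueeze(0) extracts the xmin, ymin of every anchor in set_2; the result has shape 1 × n2 × 2.
5. set_1[:, 2:].unsqueeze(1) extracts the xmax, ymax of every anchor in set_1; the result has shape n1 × 1 × 2.
6. set_2[:, 2:].unsqueeze(0) extracts the xmax, ymax of every anchor in set_2; the result has shape 1 × n2 × 2.
7. When torch.max(a, b) is applied to the tensors above, a and b are first broadcast so that both have shape n1 × n2 × 2. Entry [i, j] of the result then compares box i of set_1 with box j of set_2. The same applies to torch.min on the (xmax, ymax) corners.
8. The comparison pattern in step 7 is illustrated in the figure below.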
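The shapes described in steps 3 to 7 can be checked directly in code. The minimal sketch below uses two small, hand-picked box sets. The values are arbitrary illustrations, not from the original. It calls `torch.broadcast_tensors` to show the shape each operand is expanded to before the element-wise max.

```python
import torch

# Hypothetical sample boxes, (xmin, ymin, xmax, ymax)
set_1 = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                      [1.0, 1.0, 3.0, 3.0]])   # n1 = 2
set_2 = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                      [2.0, 2.0, 4.0, 4.0],
                      [5.0, 5.0, 6.0, 6.0]])   # n2 = 3

a = set_1[:, :2].unsqueeze(1)   # (n1, 1, 2): (xmin, ymin) of each box in set_1
b = set_2[:, :2].unsqueeze(0)   # (1, n2, 2): (xmin, ymin) of each box in set_2
print(a.shape, b.shape)         # torch.Size([2, 1, 2]) torch.Size([1, 3, 2])

# Broadcasting expands both operands to (n1, n2, 2) before the comparison
a_exp, b_exp = torch.broadcast_tensors(a, b)
print(a_exp.shape, b_exp.shape)  # torch.Size([2, 3, 2]) torch.Size([2, 3, 2])

lower_bounds = torch.max(a, b)  # element-wise max, shape (n1, n2, 2)

# Entry [i, j] compares box i of set_1 with box j of set_2
i, j = 1, 0
assert torch.equal(lower_bounds[i, j], torch.max(set_1[i, :2], set_2[j, :2]))
print(lower_bounds[i, j])       # tensor([1., 1.])
```

Each operand is broadcast along the dimension that unsqueeze set to size 1. This is why a single call to torch.max produces the lower bounds for all n1 × n2 box pairs at once.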