文章目录
1. IOU
1.1 IOU
参考:知乎
自己没法直接理解多维度多个box之间的iou,索性一步步写下来。
PyTorch 源代码
=
# IoU computation
# box1 is assumed to have shape [N, 4] and box2 shape [M, 4]
def iou(self, box1, box2):
    """Compute the pairwise IoU between two sets of boxes.

    Each row is read as (x1, y1, x2, y2): columns 0-1 are used as the
    top-left corner and columns 2-3 as the bottom-right corner.

    Args:
        self: Unused; kept so the function can still be attached to a class.
        box1: Tensor of shape [N, 4].
        box2: Tensor of shape [M, 4].

    Returns:
        Tensor of shape [N, M] whose entry (i, j) is the IoU of box1[i]
        and box2[j]. A pair whose union area is 0 yields a 0/0 division
        (NaN for float tensors), exactly as the plain formula does.
    """
    # Top-left corner of the *intersection*, hence the element-wise max.
    # [N, 1, 2] against [1, M, 2] broadcasts to [N, M, 2].
    lt = torch.max(box1[:, None, :2], box2[None, :, :2])
    # Bottom-right corner of the *intersection*, hence the element-wise min.
    rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])

    # When two boxes do not overlap, rb - lt is negative along at least one
    # axis; clamp to 0 so the intersection area becomes 0. clamp returns a
    # new tensor instead of writing into `wh` through a boolean mask.
    wh = (rb - lt).clamp(min=0)  # [N, M, 2]
    inter = wh[..., 0] * wh[..., 1]  # [N, M]

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # (N,)
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # (M,)

    # (N, 1) + (1, M) broadcasts to (N, M).
    union = area1[:, None] + area2[None, :] - inter
    return inter / union
自己实现下
import torch
# Two random box sets: `a` has 2 rows and `b` has 3 rows, 4 values per row.
a = torch.randn(2, 4)
b = torch.randn(3, 4)
tensor([[-1.3839, -2.1049, 0.0442, 0.1294],
[ 3.6880, 1.6080, -0.0313, 0.2779]])
tensor([[ 0.3556, 0.4686, 1.0932, -2.3597],
[ 1.2610, 0.2251, -0.3971, 1.7352],
[-0.5320, -2.6367, 0.6560, 0.2212]])
print(b)
# N and M are the row counts of `a` and `b` respectively.
N = a.shape[0]
M = b.shape[0]
print(N, M)
tensor([[ 0.3556, 0.4686, 1.0932, -2.3597],
[ 1.2610, 0.2251, -0.3971, 1.7352],
[-0.5320, -2.6367, 0.6560, 0.2212]])
2 3
# a1: first two columns of `a`, [N, 2] -> [N, 1, 2] -> [N, M, 2], then sigmoid
a1 = a[:, None, :2].expand(N, M, 2).sigmoid()
a1
tensor([[[0.2004, 0.1086],
[0.2004, 0.1086],
[0.2004, 0.1086]],
[[0.9756, 0.8331],
[0.9756, 0.8331],
[0.9756, 0.8331]]])
# b1: first two columns of `b`, [M, 2] -> [1, M, 2] -> [N, M, 2], then sigmoid
b1 = b[None, :, :2].expand(N, M, 2).sigmoid()
b1
tensor([[[0.5880, 0.6150],
[0.7792, 0.5560],
[0.3701, 0.0668]],
[[0.5880, 0.6150],
[0.7792, 0.5560],
[0.3701, 0.0668]]])
# Element-wise maximum of the two [N, M, 2] tensors.
lt = torch.maximum(a1, b1)
因此rb部分也是类似的:
# Shape check of the element-wise max; the value is not stored.
torch.max(a1, b1).size()
# Last two columns of each box set, expanded to [N, M, 2] like a1/b1.
a2 = torch.sigmoid(a[:, 2:].unsqueeze(1).expand(N, M, 2))
b2 = torch.sigmoid(b[:, 2:].unsqueeze(0).expand(N, M, 2))
# Element-wise minimum gives the rb counterpart of lt.
rb = torch.min(a2, b2)
print(lt)
print('-'*50)
print(rb)
# Width/height is rb minus lt, matching `wh = rb - lt` in iou();
# the reversed order would flip the sign of every extent.
wh = rb - lt
print('-'*50)
--------------------------------------------------
lt:
tensor([[[0.5880, 0.6150],
[0.7792, 0.5560],
[0.3701, 0.1086]],
[[0.9756, 0.8331],
[0.9756, 0.8331],
[0.9756, 0.8331]]])
--------------------------------------------------
rb:
tensor([[[0.5110, 0.0863],
[0.4020, 0.5323],
[0.5110, 0.5323]],
[[0.4922, 0.0863],
[0.4020, 0.5690],
[0.4922, 0.5551]]])
--------------------------------------------------
wh:
tensor