今天在看tensor版本的交并比,它比单纯的python版本的list使用更加灵活,但是看的时候也遇到了一个问题,就是unsqueeze,这个方法就是根据指定的维度进行扩展一维,比如维度为:
A =(2,4),A.unsqueeze(0),则变为 (1,2,4)
A.unsqueeze(1),则变为 (2,1,4)
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
box_b[:, :2].unsqueeze(0).expand(A, B, 2))
其中
print(box_a.size()) print(box_a) print(box_b.size()) print(box_b)
值如下
torch.Size([3, 4])
tensor([[0.0020, 0.1440, 0.8280, 0.9973],
[0.3900, 0.7920, 0.4840, 0.9973],
[0.3860, 0.1600, 0.8300, 0.9973]])
torch.Size([2, 4])
tensor([[-0.0367, -0.0367, 0.0633, 0.0633],
[-0.0574, -0.0574, 0.0840, 0.0840]])、
但是在下面代码,我突然有点懵了,为啥不能写
max_xy = torch.min(box_a[:, 2:].unsqueeze(0).expand(A, B, 2), box_b[:, 2:].unsqueeze(1).expand(A, B, 2))
就是我很好奇,凭啥box_a是拓展第一维,box_b拓展第零维,但是如果改为上面那样是不行的,为啥,因为我们要看到后面有个expand,就是你拓展的一维要扩张成什么样。
如果按照我刚刚写的,那么我们细分下:
A =3, B = 2
box_a 的 torch.Size([3, 4])
box_a[:, 2:].unsqueeze(0) 的 size() = (1,3,4)
box_b 的 torch.Size([2, 4])
box_b[:, 2:].unsqueeze(1) 的 size() = (2,1,4)
但是,这个时候,后面expand(A, B, 2) 即(3,2,2)就不行了,
因为 box_a[:, 2:].unsqueeze(0) = (1,3,4)第一维没办法拓展维2
同理 box_b[:, 2:].unsqueeze(1) = (2,1,4)第零维也没办法拓展维3
但是其实是可以的,只要把 expand(A, B, 2),改为expand(B, A, 2)=(2,3,2)这样就可以了。