一、expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。
x = torch.tensor([1, 2, 3])
>> x.expand(2, 3)
tensor([[1, 2, 3],
[1, 2, 3]])
二、repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据。
import torch
>> x = torch.tensor([1, 2, 3])
>> x.repeat(3, 2)
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
>> x2 = torch.randn(2, 3, 4)
>> x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])
三、None
在对应维度增加一个维度
bboxes_points 6824x2
a = bboxes_points[:, None, :] 6824x1x2
gt_points 3x2
b = gt_points[None, :, :] 1x3x2
c = a - b (这里使用pytorch的广播机制)
注意:这里none在哪个位置上,就在哪个维度上增加