pytorch中的repeat指的是在各指定维度上重复指定次数,例如有一个tensor(形状为(3,)):
t=torch.arange(3)
print(t) # tensor([0, 1, 2])
我们如果想让他在第一维上重复3次,也就是变成[0, 1, 2, 0, 1, 2, 0, 1, 2],需要使用
t=t.repeat((3,))
如果想让他变成纵向的重复三次,也就是变成
tensor([[0, 1, 2],
[0, 1, 2],
[0, 1, 2]])
则需要先将该tensor抬升一个维度,变成形状为(1,3),再让它在第0维上重复:
t=t.unsqueeze(0) #增加一个维度
t=t.repeat(3,1)
print(t)