repeat(*sizes)
沿着指定的维度重复tensor
看下面的例子理解更容易!
import torch
x = torch.tensor([1, 2, 3])
print(x)
print(x.shape)
print(x.repeat(3))
print("###################################")
print(x.repeat(3, 1))
print("###################################")
print(x.repeat(3,2))
print("###################################")
print(x.repeat(3,2,1))
输出结果:
tensor([1, 2, 3])
torch.Size([3])
tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])
###################################
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
###################################
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
###################################
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])
import torch
x = torch.tensor([1, 2, 3])
print(x)
print(x.shape)
print("###################################")
B = x.repeat(1,5,1)
print(B)
print(B.shape)
print("###################################")
B = x.repeat(2,5,2)
print(B)
print(B.shape)
输出结果: