Tensor.repeat(sizes) → Tensor
在具体的维度上重复tentor,废话不多说,直接上例子
输入tentor
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print (x)
print (x.shape)
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
torch.Size([3, 4])
由于x是二维的,repeat的size是三维的,因此自动的将x reshape为(1, 3, 4),(3, 1, 1)的意思是,在第一个维度上repeat 3倍, 在第二个维度上repeat 1倍,在第三个维度上repeat 1倍
a = x.repeat(3, 1, 1)
print (a)
print (a.shape)
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]]])
torch.Size([3, 3, 4])
另一种repeat方式
b = x.unsqueeze(1)
print (b.shape)
b = b.repeat(1, 3, 1)
print (b)
print (b.shape)
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12]]])
torch.Size([3, 3, 4])