pytorch中的repeat()函数可以对张量进行复制。
当参数只有一个时,参数表示在列方向上复制的次数
当参数只有两个时,第一个参数表示在行方向上复制的次数,第二个参数表示在列方向上复制的次数。
当参数有三个时,第一个参数表示在通道数方向上复制的次数,第二个参数表示在行方向上复制的次数,第三个参数表示在列方向上复制的次数。
接下来我们举一个例子来直观理解一下:
>>> x = torch.tensor([6,7,8])
>>> x.repeat(4)
tensor([[6, 7, 8, 6, 7, 8, 6, 7, 8, 6, 7, 8]])
>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,2)
tensor([[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8]])
>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,1)
tensor([[6, 7, 8],
[6, 7, 8],
[6, 7, 8],
[6, 7, 8]])
>>> x.repeat(4,2,1)
tensor([[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]]])
>>> x.repeat(4,2,1).size()
torch.Size([4, 2, 3])

1463

被折叠的 条评论
为什么被折叠?



