repeat
在 PyTorch 中,repeat()
函数用于复制张量的维度。它会将输入张量按照指定的次数重复,以生成一个新的张量。
这是 repeat()
函数的基本语法:
repeat(*sizes)
sizes
: 重复每个维度的次数。如果你只想在某些维度上重复,可以在对应位置填入1
。
下面是一个简单的例子,说明了 repeat()
函数的用法:
import torch
# 创建一个张量
x = torch.tensor([[1, 2],
[3, 4]])
# 在每个维度上分别重复 2 次
y = x.repeat(2, 2)
print(y)
输出结果是:
tensor([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
在这个例子中,原始张量 x
是一个 2x2 的矩阵。通过 x.repeat(2, 2)
,我们在每个维度上分别重复了 2 次,得到了一个新的 4x4 的张量 y
。
a = torch.tensor(
[[1,2,3],
[4,5,6]]
)
print(a)
x = a.repeat(1,1)
print(x)
x = a.repeat(1,2)
print(x)
x = a.repeat(2,1)
print(x)
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]])
tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])