torch.repeat()
torch.repeat()
函数用于在指定的维度上复制张量的元素。
语法:
torch.repeat(*sizes)
其中 *sizes
是一个可变数量的参数,用于指定张量在每个维度上重复的次数。
返回一个新的张量,其形状由输入张量和指定的重复次数决定
举例:
import torch
# 创建一个张量
x = torch.tensor([[1, 2], [3, 4]])
# 在维度0上重复2次,在维度1上重复3次
y = x.repeat(2, 3)
print(y)
结果:
tensor([[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4]])