相当于手动实现广播机制,即沿着给定的维度对tensor进行重复:
比如说对下面x的第1个通道复制三次,其余通道保持不变:
import torch
x = torch.randn(1, 3, 224, 224)
y = x.repeat(3, 1, 1, 1)
print(x.shape)
print(y.shape)
结果为:
torch.Size([1, 3, 224, 224])
torch.Size([3, 3, 224, 224])
这个在复制batch的时候用的比较多,上面的情况就相当于batch为1的3×224×224特征图复制成了batch为3