Pytorch(二)torch.repeat
前几天遇到了这个函数,感觉别人说的也不是特别清楚,记录一下
首先repeat是在原始张量上进行扩充的,然后我们再从维度的角度上去理解,第几个参数的值表示在的几维上根据原始张量进行扩充
import torch
x = torch.tensor([1, 2, 3])
print(x.shape)
# torch.Size([3])
print(x)
# tensor([1, 2, 3])
x1 = x.repeat(4) # 相当于在第0个维度上扩张4次
print("x1.shape:", x1.shape)
# x1.shape: torch.Size([12]) 得到了3*4,12维的张量
print("x1:", x1)
# x1: tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
x2 = x.repeat(3, 2)
# 当有两个参数时,因为之前的x只有1维,现在扩张成2维,先在第0维扩张3次,再在第1维扩张2次,参考下图
print("x2.shape:", x2.shape)
# x2.shape: torch.Size([3, 6])
print('x2:', x2)
# x2: tensor([[1, 2, 3, 1, 2, 3],
# [1, 2, 3, 1, 2, 3],
# [1, 2, 3, 1, 2, 3]])
三维的我举个例子,假设单通道图片,我想复制成三通道的
import torch
img = torch.arange(4).reshape(2, -1)
print(img.shape)
# torch.Size([2, 2]) H*W的图 2*2
print(img)
# tensor([[0, 1],
# [2, 3]])
repeat_img = img.repeat(3, 1, 1) # 相当于0维扩3次,1,2维不变
print(repeat_img.shape)
# torch.Size([3, 2, 2]) (C,H,W)
print(repeat_img)
# tensor([[[0, 1],
# [2, 3]],
#
# [[0, 1],
# [2, 3]],
#
# [[0, 1],
# [2, 3]]])
参考:https://blog.csdn.net/weixin_41041772/article/details/123296659