torch.Tensor.repeat 函数详解
-
官方文档:torch.Tensor.repeat
-
函数原型
Tensor.repeat(*sizes) → Tensor
-
功能
沿指定维度复制张量 -
参数
参数 类型 说明 sizes torch.Size or int...
沿每个维度重复该张量的次数 -
说明
假定参数sizes
的长度为d
,待执行操作的张量维度为dim
,则存在以下两种情况:d == dim
—— 每个维度执行对应次的复制;d > dim
—— 通过对张量添加新轴将其提升至d
维,再执行复制。- 特别说明:有别于
numpy.tile
,pytorch
不支持d < dim
的情况,numpy.tile
会通过在d
前面添加1
将其提升至dim
维,再执行复制(参见numpy.tile)。
-
代码示例
>>> a = torch.tensor([[1, 3], [2, 4]]) >>> a_1 = a.repeat(3, 2) >>> a_2 = a.repeat(2, 1, 3) >>> a.shape torch.Size([2, 2]) >>> a_1.shape torch.Size([6, 4]) >>> a_2.shape torch.Size([2, 2, 6]) >>> a tensor([1, 2, 3]) >>> a_1 tensor([[1, 3, 1, 3], [2, 4, 2, 4], [1, 3, 1, 3], [2, 4, 2, 4], [1, 3, 1, 3], [2, 4, 2, 4]]) >>> a_2 tensor([[[1, 3, 1, 3, 1, 3], [2, 4, 2, 4, 2, 4]], [[1, 3, 1, 3, 1, 3], [2, 4, 2, 4, 2, 4]]])