pytorch repeat()函数的功能演示,非常详细(代码验证)

PyTorch中的repeat()函数可以对张量进行重复扩充。

repeat()中的参数个数需 >=  tensor 维度

举例证明:

>>> abb = torch.rand([2,1,2,3])
>>> print("abb: ",abb)
abb:  tensor([[[[0.1103, 0.8384, 0.8689],
          [0.2282, 0.9317, 0.4930]]],


        [[[0.8018, 0.8036, 0.1346],
          [0.9017, 0.4192, 0.6952]]]])
>>> print("abb.shape: ",abb.shape)
abb.shape:  torch.Size([2, 1, 2, 3])
>>> bb = abb.repeat(1,2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>

看看 错误原因:RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor.

当参数只有两个时:(行的倍数,列的倍数), 此时通道数倍数为1,表示不重复扩充。

当参数有三个时:(通道数的倍数,行的倍数,列的倍数)。

直接将repeat参数对应乘上原tensor的shape就得到最终的shape,如:
shape(2,3) * repeat(1,2) = shape(2,6)  //channel = 1

shape(2,3) * repeat(2,2,1) = shape(2,4,3)  // shape(2,3) 可看作shape(1,2.3)
 

>>> bb = abb.repeat(2,2,2,3)
>>> bb
tensor([[[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930],
          [0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930]],

         [[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930],
          [0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930]]],


        [[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952],
          [0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952]],

         [[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952],
          [0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952]]],


        [[[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930],
          [0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930]],

         [[0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930],
          [0.1103, 0.8384, 0.8689, 0.1103, 0.8384, 0.8689, 0.1103, 0.8384,
           0.8689],
          [0.2282, 0.9317, 0.4930, 0.2282, 0.9317, 0.4930, 0.2282, 0.9317,
           0.4930]]],


        [[[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952],
          [0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952]],

         [[0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952],
          [0.8018, 0.8036, 0.1346, 0.8018, 0.8036, 0.1346, 0.8018, 0.8036,
           0.1346],
          [0.9017, 0.4192, 0.6952, 0.9017, 0.4192, 0.6952, 0.9017, 0.4192,
           0.6952]]]])
>>> print("bb.shape: ",bb.shape)
bb.shape:  torch.Size([4, 2, 4, 9])
>>>

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值