torch中repeat()函数的准确理解

pytorch中,repeat()这个函数的功能是很容易理解的,但是函数在不同参数下的执行逻辑通常还是会让初次接触的同学有点疑惑,今天仔细想了一下,感觉可以以这样的方式去理解:

import torch

x = torch.tensor([1,2,3])

#将一维度的x扩展到三维
xx = x.repeat(4,2,1)

/**
扩展步骤如下(倒着执行):
1  最后一个维度1:此时将[1,2,3]中的数字直接重复1次,得到[1,2,3],保持没变
2  倒数第二个维度2:先将上一步骤的结果增加一个维度,得到[[1,2,3]],然后将最外层中括号中的整体重复2次,得到[[1,2,3],[1,2,3]]
3  倒数第三个维度4:先将上一步骤的结果增加一个维度,得到[[[1,2,3],[1,2,3]]],然后将最外层中括号中的整体重复4次,得到[[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]]
4  三个维度扩展结束,得到结果。

**/

测试代码如下:

import torch

x = torch.tensor([1,2,3])

x1 = x.repeat(4)
print("x1:\n",x1)

x2 = x.repeat(4,1)
print("x2:\n"x2)

x3 = x.repeat(4,2)
print("x3:\n"x3)

x4 = x.repeat(4,2,1)
print("x4:\n"x4)

测试代码结果:

x1:
 tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]) 
x2:
 tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
x3:
 tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
x4:
 tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])

  • 49
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值