【Pytorch】 repeat()的用法详解

Pytorch的repeat()方法再深度学习中经常用到,用于复制tensor,最好的说明当然是官方文档。

repeat的用法说明很简单:重复每个张量的维度的次数。

-这里有个warrning很有意思,意思是Pytorch的repeat和numpy.repeat是不太一样的。下次填坑。

在这里插入图片描述
看官方给的例子:

import torch

x = torch.tensor([1, 2, 3])
print(x.shape)
# torch.Size([3])

print(x.repeat(4, 2))
"""
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])"""
        
x.repeat(4, 2, 1).size()
#torch.Size([4, 2, 3])

x是一维的tensor,但传入repeat的size是二维的即(4,2)维度时不对应的,看一下复制流程,x是一维tensor,但是可以看成是二维的,新增的维度的值为1。举个例子相当于把一个n维向量(行向量)看作一个一行n列的矩阵,向量是一维但矩阵是二维的。

import torch

# 原始x是一维的张量
x = torch.tensor([1, 2, 3])

# 把x的维数增加一维变成二维
x = x.reshape(1,3)

x.repeat(4,2)
# 得到相同的结果
"""
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])"""

repeat(4,2)相当于把整个tensor在行方向上复制4次,在列方向上复制2次。注意是整个tensor,而不是复制完一行接着复制下一行。

# x.shape (2,3)
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
                  
# 维度对应时在相应的维度复制即可
x.repeat(4, 2)

"""
tensor([[1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6]])

"""

再看官方例子的最后一行代码:

import torch

# 此时x是一维的
x = torch.tensor([1, 2, 3])
# 复制的是三维的
x.repeat(4, 2, 1)

# 和上面例子是一样的,先把x升到3维
x = torch.tensor([1, 2, 3])

# 把x看成是一个1通道1行3列的三维张量
x = x.reshape(1,1,3)

x.repeat(4, 2, 1)
# 对应维度复制即可得到结果
# x变成了4通道2行3列的张量
"""
tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])
 """
  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,repeat()函数可以对张量进行重复扩充。它有两种用法: 1. 当参数只有两个时,即(x.repeat(a, b)),其中a表示行的重复倍数,b表示列的重复倍数。例如,x.repeat(4, 2)会将x在行方向上重复4倍,在列方向上重复2倍。 2. 当参数有三个时,即(x.repeat(a, b, c)),其中a表示通道数的重复倍数,b表示行的重复倍数,c表示列的重复倍数。例如,x.repeat(4, 2, 1)会将x在通道数上重复4倍,在行方向上重复2倍,在列方向上不重复。 下面是一个代码例子: ```python import torch x = torch.tensor([1, 2, 3]) print(x.shape) # torch.Size([3]) print(x.repeat(4, 2)) """ tensor([[1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3]]) """ print(x.repeat(4, 2, 1).size()) # torch.Size([4, 2, 3]) ``` 总结起来,repeat()函数可以根据传入的倍数,在指定的维度上对张量进行重复扩充。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [Pytorch中torch.repeat()函数解析](https://blog.csdn.net/flyingluohaipeng/article/details/125039368)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [【Pytorchrepeat()的用法详解](https://blog.csdn.net/m0_46412065/article/details/128043821)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [pytorchrepeat方法](https://blog.csdn.net/weixin_42060572/article/details/114254532)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值