torch.repeat_interleave()与tensor.repeat()——数组的重复

torch.repeat_interleave()与 tensor.repeat()——数组的重复

torch.repeat_interleave()

torch.repeat_interleave(input, repeats, dim=None) → Tensor

功能:沿着指定的维度重复数组的元素

输入:

  • input:指定的数组
  • repeats:每个元素重复的次数,可以是张量或者是数组
  • dim:指定的维度

注意:

  • 如果不指定dim,则默认将输入数组扁平化(维数是1,因此这时repeats必须是一个数,不能是数组),并且返回一个扁平化的输出数组

  • 返回的数组与输入数组维数相同,并且除了给定的维度dim,其他维度大小与输入数组相应维度大小相同

  • repeats:如果传入数组,则必须是tensor格式。并且只能是一维数组,数组长度与输入数组inputdim维度大小相同,输入数组的具体意义如下:
    如果 r e p e a t s = [ n 1 , n 2 , … , n m ] , 则输出 [ x 1 , x 1 , … , x 1 , x 2 , x 2 , … , x m ] 其中, x 1 重复 n 1 次, x 2 重复 n 2 次, x m 重复 n m 次 如果repeats=[n_1,n_2,\dots,n_m],则输出[x_1,x_1,\dots,x_1,x_2,x_2,\dots,x_m]\\其中,x_1重复n_1次,x_2重复n_2次,x_m重复n_m次 如果repeats=[n1,n2,,nm],则输出[x1,x1,,x1,x2,x2,,xm]其中,x1重复n1次,x2重复n2次,xm重复nm

案例代码

一般用法

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,3,dim=0)
c=torch.repeat_interleave(a,3,dim=1)
print(a)
print(b)
print(c)
print(a.shape)
print(b.shape)
print(c.shape)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 沿第一维度重复后的数组
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9]])
# 沿第二维度重复后的数组
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
        [5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]])
# 原数组形状
torch.Size([2, 5])
# 沿第一维度重复后的形状
torch.Size([6, 5])
# 沿第二维度重复后的形状
torch.Size([2, 15])

当不指定dim

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,2)
print(a)
print(b)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 不指定dim时重复两次
tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])

repeats为数组格式时

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=0)
print(a)
print(b)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 第一行重复两次,第二行重复三次
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9]])

如果repeats为数组,但是大小和输入的dim大小不匹配,则会报错

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
print(a)
print(b)

输出报错,RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-c8f2c85e38df> in <module>
      1 import torch
      2 a=torch.arange(10).view(2,5)
----> 3 b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
      4 print(a)
      5 print(b)

RuntimeError: repeats must have the same size as input along dim

torch.Tensor.repeat()

Tensor.repeat(*sizes) → Tensor

功能:沿每个维度重复张量数组

输入:

  • sizes:沿每个维度重复此张量的次数

注意:

  • sizes长度必须大于等于被重复数组tensor的维数(如果tensor的维数是2,则sizes就必须大于等于2)

代码案例

import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2,3,2)
print(a)
print(b)
print(a.shape)
print(b.shape)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 重复后的数组
tensor([[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]],

        [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]]])
# 原数组形状
torch.Size([2, 5])
# 重复后的数组形状
torch.Size([2, 6, 10])

如果sizes长度小于tensor的维数,则会报错

import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2)

输出报错

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-33-19a278098c7c> in <module>
      1 import torch
      2 a=torch.arange(10).view(2,5)
----> 3 b=a.repeat(2)

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

区别

两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复

官方文档

torch.repeat_interleave():https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave

torch.tensor.repeat():https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html#torch.Tensor.repeat

点个赞支持一下吧

  • 11
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
torch.repeat_interleave()函数是PyTorch中的一个函数,用于将输入张量中的元素重复指定次数。函数的原型为torch.repeat_interleave(input, repeats, dim=None),其中input是输入张量,repeats是每个元素的重复次数,dim是需要重复的维度,默认情况下dim=None,表示将输入张量展平为向量,然后将每个元素重复repeats次,并返回重复后的张量。\[1\] 举例说明: ``` x = torch.tensor(\[1, 2, 3\]) x.repeat_interleave(2) # 输出: tensor(\[1, 1, 2, 2, 3, 3\]) y = torch.tensor(\[\[1, 2\], \[3, 4\]\]) torch.repeat_interleave(y, 2) # 输出: tensor(\[1, 1, 2, 2, 3, 3, 4, 4\]) torch.repeat_interleave(y, 3, dim=0) # 输出: tensor(\[\[1, 2\], \[1, 2\], \[1, 2\], \[3, 4\], \[3, 4\], \[3, 4\]\]) torch.repeat_interleave(y, 3, dim=1) # 输出: tensor(\[\[1, 1, 1, 2, 2, 2\], \[3, 3, 3, 4, 4, 4\]\]) torch.repeat_interleave(y, torch.tensor(\[1, 2\]), dim=0) # 输出: tensor(\[\[1, 2\], \[3, 4\], \[3, 4\]\]) ``` 以上是一些使用torch.repeat_interleave()函数的示例,可以根据需要指定重复次数和重复的维度来实现不同的重复操作。\[2\] 注意:在传入多维张量时,函数会默认将其展平为向量进行重复操作。\[2\] #### 引用[.reference_title] - *1* *3* [Pytorchtorch.repeat_interleave()函数解析](https://blog.csdn.net/flyingluohaipeng/article/details/125039411)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [torch.repeat_interleave()函数详解](https://blog.csdn.net/weixin_43823669/article/details/126283277)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值