PyTorch库学习之torch.repeat_interleave函数

PyTorch库学习之torch.repeat_interleave函数

一、简介

torch.repeat_interleave 是 PyTorch 库中的一个函数,它用于重复张量中的元素。这个函数可以沿着指定的维度重复张量中的每个元素,返回一个新的张量。当不指定维度时,会将输入张量展平,并重复每个元素。这个函数在处理序列数据或生成数据增强样本时非常有用。

二、语法和参数

语法:

torch.repeat_interleave(input, repeats, dim=None) → Tensor

参数:

  • input (torch.Tensor): 输入张量。
  • repeats (int 或 torch.Tensor): 每个元素的重复次数。如果 repeats 是一个整数,则所有元素都将重复相同的次数;如果是一个张量,则需要与 input 张量的形状相同,并且会被广播以适应输入张量的维度。
  • dim (int, 可选): 重复操作的维度。如果不指定 (None),则默认将整个张量视为一维。

返回值:

  • 返回一个新的张量,其形状与输入张量相同,但沿给定维度 dim 的大小会根据重复次数进行调整。

三、实例

3.1 重复一维张量中的每个元素
import torch
x = torch.tensor([1, 2, 3])
result = torch.repeat_interleave(x, 2)

print(result.shape)
print(result)

输出:

torch.Size([6])
tensor([1, 1, 2, 2, 3, 3])
3.2 沿着指定维度重复二维张量的元素
import torch

y = torch.tensor([[1, 2], [3, 4]])
result = torch.repeat_interleave(y, 3, dim=1)

print(result.shape)
print(result)

输出:

torch.Size([2, 6])
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
3.3 使用不同重复次数重复二维张量的行
import torch

y = torch.tensor([[1, 2], [3, 4]])
repeats_per_row = torch.tensor([2, 3])
result = torch.repeat_interleave(y, repeats_per_row, dim=0)

print(result.shape)
print(result)

输出:

torch.Size([5, 2])
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [3, 4]])

四、注意事项

  • 如果 repeats 是一个张量,它必须是一维的,并且其长度必须与 input 张量在 dim 维度上的大小相同 。
  • dim 参数未指定时,repeats 必须是一个整数,不能是一个数组 。
  • 返回的张量与输入张量在除了 dim 维度以外的其他维度上具有相同的形状 。
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值