torch.repeat()、torch.repeat_interleave()和torch.expand()

文章介绍了PyTorch中用于复制张量元素的三个函数:repeat_interleave可以在指定轴上重复张量元素,repeat用于整行或整列复制,而expand则返回张量的一个扩容视图。通过实例展示了各函数的用法和参数含义。
摘要由CSDN通过智能技术生成

1. torch.repeat_interleave(input, repeats, dim=None) 该函数用来复制张量元素

  • 参数
    – input: 输入张量
    – repeats: 对张量的每个元素复制的次数,通过广播机制实现,可以是int类型,也可以是张量类型。指定每行的复制次数。
    – dim:指定对哪个轴上的元素进行复制,默认将输入数组展开复制并返回一个展开后的输出数组

1.1 不指定轴

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
a.repeat_interleave(2)
# tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])

1.2 指定轴

1.2.1 按行
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
a.repeat_interleave(2, 1)
# tensor([[1, 1, 2, 2, 3, 3],
#        [4, 4, 5, 5, 6, 6]])
1.2.2 按列
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
a.repeat_interleave(2, 0)
tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6]])
1.3 使用张量指定行的复制次数
aa = torch.tensor([[1, 2], [3, 4], [5, 6]])
torch.repeat_interleave(aa, torch.tensor([3, 2, 2]), dim=0)
tensor([[1, 2],
        [1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [5, 6],
        [5, 6]])

2 torch.repeat(*sizes) 整行/整列复制

  • 例子一
bb = torch.tensor([1, 2, 3])
bb.repeat(4, 2)
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
  • 例子二
cc = bb.repeat(4, 2, 1)
cc.shape  # (4, 2, 3)
tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])
         

3. torch.expand(*sizes)

返回原张量扩容后的一个新的视图

  • 例子一
x = torch.tensor([[1], [2], [3]])
x.size()  # (3, 1)
b = x.expand(3, 4)  # 等价于x.expand(-1, 4), -1 代表不变
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值