2020-10-20

Torch.repeat函数理解

一、官方定义

官方文档定义如下:
repeat(*sizes) → Tensor
Repeats this tensor along the specified dimensions.

Unlike expand(), this function copies the tensor’s data.

WARNING

repeat() behaves differently from numpy.repeat, but is more similar to numpy.tile. For the operator similar to numpy.repeat, see torch.repeat_interleave().

Parameters
sizes (torch.Size or int…) – The number of times to repeat this tensor along each dimension

Example:

x = torch.tensor([1, 2, 3])
x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])

x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

二、理解

官方文档中提到与expand不同的是,expand函数仅会改变tensor的视图,而repeat会拷贝原tensor的数据。
从几个简单的例子能帮助更好的理解。

import torch
In [1]: import torch                                                                                                                                                                                                                                                 

In [2]: a = torch.randint(5,(2,3))                                                                                                                                                                                                                                   

In [3]: a.repeat(1,1)                                                                                                                                                                                                                                                
Out[3]: 
tensor([[2, 1, 3],
        [4, 3, 1]])#在各自维度复制一个,则保持不变。
In [4]: a.repeat(1,1,1)                                                                                                                                                                                                                                              
Out[4]: 
tensor([[[2, 1, 3],
         [4, 3, 1]]])#加了一个维度。
In [5]: a.repeat(1,2,3)                                                                                                                                                                                                                                              
Out[5]: 
tensor([[[2, 1, 3, 2, 1, 3, 2, 1, 3],
         [4, 3, 1, 4, 3, 1, 4, 3, 1],
         [2, 1, 3, 2, 1, 3, 2, 1, 3],
         [4, 3, 1, 4, 3, 1, 4, 3, 1]]])
         #在1维度复制1,2维度复制2,3维度3个       

最后一个理解可以这样,最后输出的张量第一个维度为1保持。
第二个维度需要复制两个,第三个维度复制三个。按照numpy.tile来理解就是在维度上堆砌相同的张量,从最后一个维度来看 横向堆砌三个张量,原张量本来应该为(2,3)所以现在应该是(2,9),横向堆砌两个,所以维度变为(4,9),最终输出为(1,4,9)。
再举一个例子帮助理解。

In [8]: a.repeat(4,2,1)                                                                                                                                                                                                                                              
Out[8]: 
tensor([[[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]]])#输出形状为(4,4,3)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值