pytorch : expand 和 repeat 函数

expand 函数

expand(*sizes) -> Tensor
*sizes(torch.Size or int) - the desired expanded size
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

expand用于扩展tensor数据。但有以下注意点:

  1. 该函数不复制数据
  2. 扩展时只在能度数是1的维度上扩展
  3. 生成的对象与原对象共享内存
import torch

a = torch.tensor([1,2])
b = a.expand(2,-1) # -1 代表此维度不变
print('b : ', b)
print('\nafter modifying b')
b[0][0]=10
print('a : ', a)
print('b : ', b)
    b :  tensor([[1, 2],
            [1, 2]])

    after modifying b
    a :  tensor([10,  2])
    b :  tensor([[10,  2],
            [10,  2]])
c = a.expand(2,4) # tesor 'a'最后一维维度是2,所以扩展时出错
RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 4].  Tensor sizes: [2]

repeat 函数

repeat(*sizes) -> Tensor
*size(torch.Size or int) - The number of times to repeat this tensor along each dimension.
Repeats this tensor along the specified dimensions.

返回tensor在某个维度上扩展后的张量.注意:

  1. 此函数会生成新的数据变量,和原tensor不共享内存
d = a.repeat(2,2)
print('d: ', d)
d[0][0] = 10
print('\nafter modifying d ')
print('a: ', a)
print('d: ', d)

a.repeat(2,4) # 此参数 expand() 函数不通过
d:  tensor([[1, 2, 1, 2],
        [1, 2, 1, 2]])
        
after modifying d 
a:  tensor([1, 2])
d:  tensor([[10,  2,  1,  2],
        [ 1,  2,  1,  2]])

tensor([[1, 2, 1, 2, 1, 2, 1, 2],
        [1, 2, 1, 2, 1, 2, 1, 2]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值