pytorch语法unsqueeze、squeeze和gather

今天看网上DCN的解读,顺带看了代码,决定这次把unsqueeze和gather搞明白。
unsqueeze的作用跟字面意思一样,与squeeze作用相反,就是扩维,请看代码。
>>> import torch
>>> a = torch.tensor([1,2,3,4,5])
>>> a.shape
torch.Size([5])
可以看到a的维度为1.
>>> b=a.unsqueeze(dim=0)
>>> b
tensor([[1, 2, 3, 4, 5]])
在第0维对a扩维,可以看到a的维度变为了2,再看一下shape.
>>> b.shape
torch.Size([1, 5])
可以看到在原先的第0维变为了第1维。
>>> c=b.unsqueeze(dim=-1)
>>> c
tensor([[[1],
         [2],
         [3],
         [4],
         [5]]])
在最后一维对b进行扩维,再看一下shape。
>>> c.shape
torch.Size([1, 5, 1])
可以看到c的维度增加了,在最后一维,再向c加一个维度.
>>> d=c.unsqueeze(dim=-1)
>>> d
tensor([[[[1]],
         [[2]],
         [[3]],
         [[4]],
         [[5]]]])
>>> d.shape
torch.Size([1, 5, 1, 1])
再验证下squeeze。
>>> e = d.squeeze(0)
>>> e.shape
torch.Size([5, 1, 1])
>>> e = d.squeeze(1)
>>> e.shape
torch.Size([1, 5, 1, 1])
>>> e = d.squeeze(-1)
>>> e.shape
torch.Size([1, 5, 1])
>>> e = d.squeeze(-2)
>>> e.shape
torch.Size([1, 5, 1])
可以看到squeeze只能将多余的维度降维,包含数据的维度是没有用的。再看一下gather。
gather的作用是按index读取input中的数据。index要求与input维度相同。
>>> a
tensor([[1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7]])
>>> index
tensor([[0],
        [1],
        [2],
        [3]])
>>> a.shape
torch.Size([4, 4])
>>> index.shape
torch.Size([4, 1])
首先给定一个输入a,维度为2.index的维度也为2,两者维度要相同,而且若要gather对input第i维进行操作,那么index在第i维的维度应大于等于1.(翻译文档)因为若对input第i维进行操作,input的第i维定不为空,这样index在该维度就不能没有数值,也就是大于等于1.
>>> out = torch.gather(a,0,index)
>>> out
tensor([[1],
        [2],
        [3],
        [4]])
>>> out = torch.gather(a,1,index)
>>> out
tensor([[1],
        [3],
        [5],
        [7]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值