PyTorch中squeeze()和unsqueeze()函数理解

squeeze(arg)
表示若第arg维的维度值为1,则去掉该维度,否则tensor不变。(即若tensor.shape()[arg] == 1,则去掉该维度)
例如:
一个维度为2x1x2x1x2的tensor,不用去想它长什么样儿,squeeze(0)就是不变,squeeze(1)就是变成2x2x1x2。(0是从最左边的维度算起的)

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

unsqueeze(arg)
与squeeze(arg)函数作用相反,表示在第arg维增加一个维度为1的维度。
啥意思呢?
比如一个tensor的shape为3x3,那么unsqueeze(0)就是变成1x3x3,unsqueeze(1)就是变成3x1x3.
再如下面这个官方的例子,得看好几眼才能看明白怎么回事。
其实可以这样理解:x的shape为:4,unsqueeze(0)就是把shape变成1x4;unsqueeze(1)就是把shape变成4x1。

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

参考:
[1] https://pytorch.org/docs/1.11/generated/torch.unsqueeze.html#torch.unsqueeze
[2] https://www.cnblogs.com/sbj123456789/p/9231571.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值