【pytorch】squeeze()和unsqueeze()函数介绍

在pytorch中,我们对张量Tensor的维度进行压缩或者扩充(被压缩或者扩充的维度为1),经常使用的是squeeze()函数和unsqueeze()函数

1. torch.squeeze(input, dim=None)

 

用于降维。将 input 中维度为1的部分去除,当维度大于等于2时,squeeze()无作用。

也可通过 input.squeeze( dim=None, out=None)调用。

  • input(Tensor):输入张量,即被操作目标
  • dim(int, optional):在指定维去掉一个维度。若不指定则自动寻找,指定则当指定的维度为1时去掉,不为1时则不改变

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

 示例

# 示例1
a = torch.Tensor(1,3)
>>
tensor([[-1.37,4.56,-3.57]])
 
print a.squeeze(0) #第一个维度大小是1,所以去除
>>
tensor([-1.37,4.56,-3.57])
 
print a.squeeze(1) ##第二个维度大小是3,所以不去除
>>
tensor([[-1.37,4.56,-3.57]])
 
# 示例2
c = torch.Tensor(3,1)
print c
>>
tensor([[-3.54],
[3.09],
[0.00]])
 
print c.squeeze(0)##第一个维度大小不是1,所以不去除
>>
tensor([[-3.54],
[3.09],
[0.00]])
 
print c.squeeze(1)#第二个维度大小是1,所以去除
>>
tensor([-3.54,3.09,0.00])


# 示例3
x = torch.zeros(2, 1, 2, 1, 2)
x.size()
>>
torch.Size([2, 1, 2, 1, 2])

y = torch.squeeze(x)
y.size()
>>
torch.Size([2, 2, 2])

y = torch.squeeze(x, 0)
y.size()
>>
torch.Size([2, 1, 2, 1, 2])

2.  torch.unsqueeze(input, dim)

为pytorch中的tensor增加一个维度。

 也可通过 input.unsqueeze( dim=None, out=None)调用。

  • input(Tensor):输入张量,即被操作目标
  • dim(int, optional):在哪一维增加一个维度,dim必须被指定

示例

import torch
a = torch.arange(12).reshape([3,4])
print(a)
b = a.unsqueeze(1)
print(b)
>>
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

参考官方文档:

torch.squeeze — PyTorch 2.0 documentation

torch.unsqueeze — PyTorch 2.0 documentation

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值