pytorch中的squeeze和unsqueeze函数的使用

应用场景:当我们进行深度学习使用Image函数导入图片时,默认它的维度为[C, H, W],此时根据模型的需要导入batch这一维度。
部分程序
# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('dog.jpg')
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)      # 对数据增加一个新维度(batch),因为tensor的参数是[batch, channel, height, width]

unsqueeze 指定哪个下标就增加那个下标长度为1的维度(必须加上下标,)。

squeeze 从中删除长度为1的维度。指定下标就删除那个下标长度为1的维度,若不为1则不删除。

具体详情见如下代码:

unsquenze(增加长度1的维度):

import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(0)    
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(1)    # 等同于a = a.unsqueeze(-2)
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]],

        [[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])

​
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(2)    # 等同于a = a.unsqueeze(-1)
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0.],
         [1.],
         [2.]],

        [[3.],
         [4.],
         [5.]],

        [[6.],
         [7.],
         [8.]]], dtype=torch.float64)
torch.Size([3, 3, 1])

squeeze(减少长度为1的维度):

import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze()       # 等同于a = a.unsqueeze(-1)         
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况


# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
   
  
   
     
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(0)      # 此时删除a的下标为0的维度长度为1         
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]],

        [[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(2)      # 此时删除a的下标为2的维度长度为1(等价于a = a.squeeze(-2))       
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况

# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])

import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(1)      # 此时小标为1或3时,对应的维度长度均为3,所以squeeze没有起作用        
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况

# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])

  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值