Pytorch中squeeze的使用

 squeeze

import pytorch
torch.squeeze(input, dim=None)

squeeze是一个用来压缩维度的函数,input为指定的输入的张量,dim这个参数为输入参量中需要压缩的维度。举个栗子

x=torch.ones(1,2,5)
print(x)

'''
输出:
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
'''

列出维度为(1,2,5)的三维张量。这里解释一下这个三维如何看,首先三维对应着三个中括号。拆掉第一层(最外面一层)中括号,会发现里面只有一个数据(一个中括号)。

拆掉第二层中括号,会发现里面有两个数据(两个中括号)。

拆掉第三层其中一个中括号,会发现有五个数据(1.)

正好对应(1,2,5)维。

print(x.squeeze(0))
'''
对上述张量操作
输出:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
'''

print(x.squeeze(1))
'''
对上述张量操作
输出:
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
'''

squeeze(n)本质上是将第n维的维数由1转化成0,若第n维不为1,则不进行变化。

第一个例子中。对(1,2,5)维张量,squeeze(0)将第一个维度进行压缩。由于第一个维度为1。则输出转化为(2,5)维的张量。

第二个例子中。对(1,2,5)维张量,squeeze(1)将第二个维度进行压缩。由于第二个维度是2,不为1,不进行变化。输出不变。

unsqueeze

print(x.unsqueeze(2))
'''
tensor([[[[1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1.]]]])
'''
print(x.unsqueeze(3))
'''
tensor([[[[1.],
          [1.],
          [1.],
          [1.],
          [1.]],

         [[1.],
          [1.],
          [1.],
          [1.],
          [1.]]]])
'''

unsqueeze(n)表示的是在第n维度增加一个大小为1的维度

第一个例子中转化成(1,2,1,5)维

第二个例子转化成(1,2,5,1)维

菜鸟刚学,仅记录学习心得,欢迎交流~

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值