Pytorch-squeeze()和unsqueeze()

1.squeeze(dim)

其中 dim 从 0 开始计数。

函数效果为将参数对应的维去掉,需满足该维值为1,

举个例子就是 view((1, 2, 3)) 中,只有 1 那个维才可以被 squeeze() 降维。

初始化一个三维的 Tensor :

# coding=utf-8

import torch

x = torch.arange(6).view((1, 2, 3))  # torch.Size([1, 2, 3])

使用 squeeze(0) 或 squeeze(-3) 对其降维:

temp1 = x.squeeze(0)  # torch.Size([2, 3])
temp2 = x.view((2, 3))  # torch.Size([2, 3])

print('temp1: {}\n temp2: {}'.format(temp1, temp2))

输出:
temp1: tensor([[0, 1, 2],
               [3, 4, 5]])
temp2: tensor([[0, 1, 2],
               [3, 4, 5]])

这里使用 view() 与其进行对比,其输出与 squeeze() 一致。

试一试对 x 使用 squeeze(1):

temp3 = x.squeeze(1)  # torch.Size([1, 2, 3])

print('x: {}\n temp3: {}'.format(x, temp3))

输出:
x: tensor([[[0, 1, 2],
            [3, 4, 5]]])
temp3: tensor([[[0, 1, 2],
                [3, 4, 5]]])

发现程序没有报错,且 temp3 的值与 x 的值一样,并没有降维的效果。

 

2.unsqueeze(dim)

其中 dim 从 0 开始计数。

与 squeeze() 对应的,unsqueeze() 则是增加维度。

初始化一个二维的 Tensor:

# coding=utf-8

import torch

x = torch.arange(6).view((2, 3))  # torch.Size([2, 3])

对其加入一个维度:

temp1 = x.unsqueeze(1)  # torch.Size([2, 1, 3])
temp2 = x.unsqueeze(-2)  # torch.Size([2, 1, 3])
temp3 = x.view((2, 1, 3))  # torch.Size([2, 1, 3])

print('temp1: {}\ntemp2: {}\n temp3: {}'.format(temp1, temp2, temp3))

输出:
temp1: tensor([[[0, 1, 2]],
               [[3, 4, 5]]])
temp2: tensor([[[0, 1, 2]],
               [[3, 4, 5]]])
temp3: tensor([[[0, 1, 2]],
               [[3, 4, 5]]])

可以看到与使用 view() 函数效果一致。

 

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值