【pytorch】torch.unsqueeze() 、torch.squeeze()

1. torch.unsqueeze 详解

torch.unsqueeze(input, dim, out=None)

  • 作用扩展维度

返回一个新的张量,对输入的既定位置插入维度 1

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果dim为负,则将会被转化dim+input.dim()+1
  • 参数:
  • tensor (Tensor) – 输入张量
  • dim (int) – 插入维度的索引
  • out (Tensor, optional) – 结果张量
import torch

# 创建一个二维张量(矩阵)
x = torch.tensor([[1, 2], [3, 4]])
print(x.size())

x维度是 2 * 2 

# 在第0维增加一个维度
y = torch.unsqueeze(x, 0)
print(y.size())

在第0维增加了一个维度,y维度是 1 * 2 * 2 

  

2. unsqueeze_和 unsqueeze 的区别

PyTorch中的 XXX_ 和 XXX 实现的功能都是相同的,唯一不同的是前者进行的是 in_place 操作。

unsqueeze_ 和 unsqueeze 实现一样的功能,区别在于 unsqueeze_ 是 in_place 操作,即 unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变,想要获取 unsqueeze 后的值必须赋予个新值, unsqueeze_ 则会对自己改变

print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])

b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([1., 2., 3., 4.])


print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])

print(a.unsqueeze_(1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

3. torch.squeeze 详解

  • 作用降维

torch.squeeze(input, dim=None, out=None)

将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
  • 参数:
  • input (Tensor) – 输入张量
  • dim (int, optional) – 如果给定,则input只会在给定维度挤压
  • out (Tensor, optional) – 输出张量

为何只去掉 1 呢?

多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。

import torch


m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])

# 去掉所有大小为1的维度
n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])

# 去掉指定维度上大小为1的维度,如果指定的维度不是1,则不会对该维度进行操作
n = torch.squeeze(m, 0)  # 当给定dim时,那么挤压操作只在给定维度上
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

# 去掉指定维度上大小为1的维度,如果指定的维度是1,则会对该维度进行删除
n = torch.squeeze(m, 1)
print(n.size())  # torch.Size([2, 2, 1, 2])

当然你也可以使用:

x.squeeze() 来替换 torch.squeeze(x)

torch.unsqueeze() 和 torch.squeeze() - 知乎

  • 2
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

飞越太平洋.

太谢谢了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值