Pytorch基础:Tensor的squeeze和unsqueeze方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        在Pytorch中,squeeze和unsqueeze是Tensor类的一个重要方法,同时它们也是torch模块中的一个函数,它们的语法如下所示。 

Tensor.squeeze(dim=None) → Tensor
torch.squeeze(input, dim=None) → Tensor

input (Tensor) – the input tensor.
dim (int or tuple of ints, optional) – if given, the input will be squeezed only in the specified dimensions.


Tensor.unsqueeze(dim) → Tensor
torch.unsqueeze(input, dim) → Tensor

input (Tensor) – the input tensor.
dim (int) – the index at which to insert the singleton dimension

一、squeeze

        squeeze函数(或方法)返回一个新的张量,该张量移除了原张量中大小为1的维度,例如:输入张量的形状是(A×1×B×C×1×D),使用了squeeze函数(或方法)后,输出张量的形状是(A×B×C×D)。请注意:输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。默认情况下,squeeze将移除所有尺寸为1的维度,如果传递了dim参数,则会将dim中的维度展开。dim的范围可以是[-input.dim()-1, input.dim()],其中负数索引表示从后往前数的位置,例如-1代表最后一个维度。

        可以看下面的例子以更好的理解:

import torch

# 创建一个形状为 (2, 1, 2, 1, 2) 的张量
x = torch.zeros(2, 1, 2, 1, 2)
print(x, x.size(), id(x))

# 移除所有大小为1的维度
a = torch.squeeze(x)  # 等价于 a = x.squeeze()
print(a, a.size(), id(a))

# 尝试移除第0维度(由于第0维度大小不为1,因此不改变形状)
b = torch.squeeze(x, 0)  # 等价于 b = x.squeeze(0)
print(b, b.size(), id(b))

# 移除第1维度(第1维度大小为1)
c = torch.squeeze(x, 1)  # 等价于 c = x.squeeze(1)
print(c, c.size(), id(c))

# 移除第1、第2和第3维度(第1和第3维度大小为1,第2维度不变)
d = torch.squeeze(x, (1, 2, 3))  # 等价于 d = x.squeeze((1, 2, 3))
print(d, d.size(), id(d))

# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() == a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() == d.storage().data_ptr()) # 共享底层存储空间


输出:
tensor([[[[[0., 0.]],
          [[0., 0.]]]],
        [[[[0., 0.]],
          [[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899057117680

tensor([[[0., 0.],
         [0., 0.]],
        [[0., 0.],
         [0., 0.]]]) torch.Size([2, 2, 2]) 1899057158240

tensor([[[[[0., 0.]],
          [[0., 0.]]]],
        [[[[0., 0.]],
          [[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899737467296

tensor([[[[0., 0.]],
         [[0., 0.]]],
        [[[0., 0.]],
         [[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1899737467376

tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]]) torch.Size([2, 2, 2]) 1899737467216
True

二、 unsqueeze

        unsqueeze函数(或方法)函数返回一个新的张量,该张量在指定维度(dim)插入一个大小为1的维度。使用unsqueeze函数(或方法)后,输入张量的形状会相应增加一个维度。例如,输入张量的形状是(A×B×C),在第1维度使用unsqueeze后,输出张量的形状将变为(A×1×B×C)。请注意,输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。dim的范围可以是[-input.dim(), input.dim()-1],其中负数索引表示从后往前数的位置,例如-1代表最后一个维度。

         可以看下面的例子以更好的理解:

import torch

# 创建一个形状为 (2, 2, 2) 的张量
x = torch.zeros(2, 2, 2)
print(x, x.size(), id(x))

# 在第0维度插入单维度
a = torch.unsqueeze(x, 0)  # 等价于 a = x.unsqueeze(0)
print(a, a.size(), id(a))

# 在第1维度插入单维度
b = torch.unsqueeze(x, 1)  # 等价于 b = x.unsqueeze(1)
print(b, b.size(), id(b))

# 在第2维度插入单维度
c = torch.unsqueeze(x, 2)  # 等价于 c = x.unsqueeze(2)
print(c, c.size(), id(c))

# 在第3维度插入单维度
d = torch.unsqueeze(x, 3)  # 等价于 d = x.unsqueeze(3)
print(d, d.size(), id(d))

# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() == a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() == d.storage().data_ptr())  # 共享底层存储空间


输出:
tensor([[[0., 0.],
         [0., 0.]],
        [[0., 0.],
         [0., 0.]]]) torch.Size([2, 2, 2]) 1509028592032

tensor([[[[0., 0.],
          [0., 0.]],
         [[0., 0.],
          [0., 0.]]]]) torch.Size([1, 2, 2, 2]) 1509028632592

tensor([[[[0., 0.],
          [0., 0.]]],
        [[[0., 0.],
          [0., 0.]]]]) torch.Size([2, 1, 2, 2]) 1507561225888

tensor([[[[0., 0.]],
         [[0., 0.]]],
        [[[0., 0.]],
         [[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1507561391824

tensor([[[[0.],
          [0.]],
         [[0.],
          [0.]]],
        [[[0.],
          [0.]],
         [[0.],
          [0.]]]]) torch.Size([2, 2, 2, 1]) 1507561391904
True

三、squeeze_和unsqueeze_

        Tensor还有拥有squeeze_和unsqueeze_方法,它们是原地操作(in-place)版本的,意思是它们会直接对原始张量进行操作,而不是创建新的张量。squeeze_和unsqueeze_的语法和功能与squeeze和unsqueeze类似,只是它们会直接修改调用它们的张量本身,除此之外它返回的也是修改后的变量本身(因此无需通过返回值传递常量)。由于是原地操作,squeeze_和unsqueeze_会节省内存,但是要小心使用,因为它们会改变原始数据。以下是一些例子,帮助理解这两个原地操作方法的使用:

import torch

# 创建一个形状为 (2, 1, 2, 1, 2) 的张量
x = torch.zeros(2, 1, 2, 1, 2)
print(x, x.size(), id(x))

# 原地移除所有大小为1的维度
x.squeeze_()  # 等价于 x = x.squeeze_()
print(x, x.size(), id(x))

# 重新创建一个形状为 (2, 1, 2, 1, 2) 的张量
x = torch.zeros(2, 1, 2, 1, 2)
print(x, x.size(), id(x))

# 原地移除第1维度(第1维度大小为1)
x.squeeze_(1)  # 等价于 x = x.squeeze_(1)
print(x, x.size(), id(x))

# 重新创建一个形状为 (2, 2, 2) 的张量
x = torch.zeros(2, 2, 2)
print(x, x.size(), id(x))

# 原地在第1维度插入单维度
x.unsqueeze_(1)  # 等价于 x = x.unsqueeze_(1)
print(x, x.size(), id(x))

输出:
tensor([[[[[0., 0.]],
          [[0., 0.]]]],
        [[[[0., 0.]],
          [[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 140209009511456

tensor([[[0., 0.],
         [0., 0.]],
        [[0., 0.],
         [0., 0.]]]) torch.Size([2, 2, 2]) 140209009511456

tensor([[[[[0., 0.]],
          [[0., 0.]]]],
        [[[[0., 0.]],
          [[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 140204776090096

tensor([[[[0., 0.]],
         [[0., 0.]]],
        [[[0., 0.]],
         [[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 140204776090096

tensor([[[0., 0.],
         [0., 0.]],
        [[0., 0.],
         [0., 0.]]]) torch.Size([2, 2, 2]) 140209009511455

tensor([[[[0., 0.],
          [0., 0.]]],
        [[[0., 0.],
          [0., 0.]]]]) torch.Size([2, 1, 2, 2]) 140209009511455
### PyTorchTensor 的 `squeeze`、`reshape` `unsqueeze` 函数详解 #### 1. `torch.squeeze()` 函数 `squeeze()` 函数用于移除张量中尺寸为 1 的维度。如果指定了某个特定的维度,则仅当该维度大小为 1 时才会被移除;如果没有指定任何维度参数,则会自动移除所有尺寸为 1 的维度。 - **语法**: ```python torch.squeeze(input, dim=None) ``` - **功能说明**: - 如果不提供 `dim` 参数,那么所有尺寸为 1 的维度都会被删除。 - 如果提供了 `dim` 参数,只有这个维度上的大小为 1 才会被删除,其他情况下不会有任何改变。 - **示例代码**: ```python import torch a = torch.tensor([[[[1], [2], [3]]]]) print("原始张量:", a) print("原始形状:", a.shape) # 输出: (1, 3, 1) b = torch.squeeze(a) print("经过 squeeze 后的张量:", b) print("经过 squeeze 后的形状:", b.shape) # 输出: (3,) c = torch.squeeze(a, dim=0) print("沿第 0 维度挤压后的张量:", c) print("沿第 0 维度挤压后的形状:", c.shape) # 输出: (3, 1) ``` --- #### 2. `torch.unsqueeze()` 函数 `unsqueeze()` 函数的作用是在指定位置增加一个新的维度,新维度的大小默认为 1。 - **语法**: ```python torch.unsqueeze(input, dim) ``` - **功能说明**: - 新增的维度由 `dim` 参数决定。 - 增加的新维度可以位于任意位置(前、中间或后)。 - **示例代码**: ```python d = torch.tensor([1, 2, 3]) print("原始张量:", d) print("原始形状:", d.shape) # 输出: (3,) e = torch.unsqueeze(d, dim=0) print("新增第 0 维度后的张量:", e) print("新增第 0 维度后的形状:", e.shape) # 输出: (1, 3) f = torch.unsqueeze(d, dim=1) print("新增第 1 维度后的张量:", f) print("新增第 1 维度后的形状:", f.shape) # 输出: (3, 1) ``` --- #### 3. `torch.reshape()` 函数 `reshape()` 函数允许重新定义张量的形状,只要满足总元素数量不变即可。 - **语法**: ```python input.reshape(shape) ``` - **功能说明**: - 可以通过 `-1` 自动推导某一个维度的大小。 - 不同于 `.view()` 方法,`.reshape()` 支持跨内存布局的操作,因此即使原张量不是连续存储也可以成功调用。 - **示例代码**: ```python g = torch.arange(6) print("原始张量:", g) print("原始形状:", g.shape) # 输出: (6,) h = g.reshape(2, 3) print("重塑后的张量:", h) print("重塑后的形状:", h.shape) # 输出: (2, 3) i = g.reshape(-1, 2) print("使用 -1 推导后的张量:", i) print("使用 -1 推导后的形状:", i.shape) # 输出: (3, 2) ``` --- ### 总结对比 | 功能 | 描述 | |--------------|----------------------------------------------------------------------------------------| | `squeeze()` | 移除尺寸为 1 的维度[^1] | | `unsqueeze()`| 在指定位置插入新的维度,其大小固定为 1[^5] | | `reshape()` | 更改张量的整体结构,前提是保持总的元素数目一致[^3] | 上述三个函数均适用于调整张量的维度,在实际应用中可以根据需求灵活组合它们来完成复杂的矩阵变换操作。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

日晨难再

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值