Python-squeeze()、unsqueeze()函数的理解

本文详细介绍了PyTorch中张量的维度操作方法,重点讲解了squeeze()和unsqueeze()函数的功能及使用方法,并通过实例演示如何进行降维和增维操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1. 降维torch.squeeze(input, dim=None, out=None)

简单示例

matplotlib画图示例

2.增维 torch.unsqueeze(input, dim, out=None)

简单示例

3.参考


1. 降维torch.squeeze(input, dim=None, out=None)

函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。

  • 当给定dim时,那么挤压操作只在给定维度上。即若tensor.size(dim) = 1,则去掉该维度
    • 其中squeeze(0)代表若第一维度值为1则去除第一维度
    • squeeze(1)代表若第二维度值为1则去除第二维度
    • -1,去除最后维度值为1的维度
  • 当不给定dim时,将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D)(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)(A×B×C×D)

例如,输入形状为: (A×1×B)(A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)(A×B)。

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

参数:

  • input (Tensor) – 输入张量
  • dim (int, optional) – 如果给定,则input只会在给定维度挤压,维度的索引(从0开始)
  • out (Tensor, optional) – 输出张量

简单示例

a = torch.Tensor(1,3)
>>
tensor([[-1.37,4.56,-3.57]])

print a.squeeze(0) #第一个维度大小确实是1,所以可以去除
>>
tensor([-1.37,4.56,-3.57])

print a.squeeze(1) ##第二个维度大小是3,所以不能去除
>>
tensor([[-1.37,4.56,-3.57]])

#例子2
b = torch.Tensor(2,3)
print b
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])

print b.squeeze(0)##第一个维度大小不是1,所以不能去除
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])

print b.squeeze(1) ##第二个维度大小是3,所以不能去除
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])

#例子3
c = torch.Tensor(3,1)
print c
>>
tensor([[-3.54],
[3.09],
[0.00]])

print c.squeeze(0)##第一个维度大小不是1,所以不能去除
>>
tensor([[-3.54],
[3.09],
[0.00]])

print c.squeeze(1)#第二个维度大小确实是1,所以可以去除
>>
tensor([-3.54,3.09,0.00])

matplotlib画图示例

import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
#无法正常显示图示案例
squares =np.array([[1,4,9,16,25]]) 
squares.shape       #要显示的数组为可表示1行5列的向量的数组
(1, 5)
plt.plot(squares)
plt.show()

 

#正常显示图示案例
#通过np.squeeze()函数转换后,要显示的数组变成了秩为1的数组,即(5,)
plt.plot(np.squeeze(squares))    
plt.show()

 

np.squeeze(squares).shape
(5,)

 


2.增维 torch.unsqueeze(input, dim, out=None)

增加大小为1的维度,也就是返回一个新的张量,对输入的指定位置插入维度 1且必须指明维度

  • x = torch.unsqueeze(x, 3) # 在第3个维度上扩展

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果dim为负,则将会被转化dim+input.dim()+1例如对于一个(3,2,4)的tensor,其dim可以选择为none,-1,0,1,2

  • none:所有元素的max,得到一个max值
  • -1:若dim为负,则将被转化为dim+input.dim()+1,即2
    • -1+2+1
  • 0:最粗粒度的方向,在第1维插入一个维度
  • 1:在第2维插入一个维度
  • 2:最细粒度的方向,在第3维插入一个维度
  • -3:在倒数第3维插入一个维度,在本例子也就是第一维
  • 一句话概括:dim越大,越深入,none即所有最小元素参与计算。

参数:

  • tensor (Tensor) – 输入张量
  • dim (int) – 插入维度的索引(从0开始)
  • out (Tensor, optional) – 结果张量

简单示例

import torch
 
x = torch.ones(4)
print(x)
print(x.size())
 
y = torch.unsqueeze(x, 0)
print(y)
print(y.size())
 
z = torch.unsqueeze(x, 1)
print(z)
print(z.size())

结果

tensor([1., 1., 1., 1.])
torch.Size([4])
tensor([[1., 1., 1., 1.]])
torch.Size([1, 4])
tensor([[1.],
        [1.],
        [1.],
        [1.]])
torch.Size([4, 1])

 分析

插入维度之前:

[ 1, 1, 1, 1 ]

在第0维插入一个维度,使其变成(1,4),即在最外层插入一个中括号即可:

[ [ 1, 1, 1, 1 ] ]

在第1维插入一个维度,使其变成(4,1)

[ [1], [1], [1], [1] ]

3.参考

pytorch中对维度及其squeeze()、unsqueeze()函数的理解

torch.squeeze()和unsqueeze()

Numpy库学习—squeeze()函数

 

 

<think>好的,我现在需要回答用户关于PyTorch中unsqueeze函数的使用说明和示例的问题。首先,我需要回忆一下unsqueeze函数的基本功能。根据用户提供的引用内容,比如引用[2]和引用[3],unsqueeze函数用于在张量的指定位置插入一个维度,也就是增加一个维度。例如,一个形状为(3,5)的张量,在第0维使用unsqueeze(0)后,形状变为(1,3,5)。这点需要明确说明。 接下来,我需要整理使用说明。根据引用[3]和引用[5],unsqueeze有两种调用方式,一种是torch.unsqueeze(input, dim),另一种是张量自带的x.unsqueeze(dim)方法。要指出dim参数可以是正数或负数,负数表示从后往前的位置,比如-1表示最后一个维度之后。需要注意的是,dim的取值范围应在[-input.dim()-1, input.dim()]之间,否则会报错。这部分需要详细说明,避免用户使用时出错。 然后是示例部分。用户提供的引用[3]中有一个示例,展示了在第0维添加维度。我应该补充更多的例子,比如在不同维度上使用正负索引,以及展示添加维度后的形状变化。例如,一个形状为(2,3)的张量,在dim=1时变为(2,1,3),dim=-1时变为(2,3,1)。这样可以让用户更直观地理解不同dim参数的效果。 此外,还需要提到unsqueeze的常见应用场景,比如数据预处理、广播机制和模型输入维度调整。根据引用[4]中的张量操作,可以举一个数据预处理的例子,比如需要将图像数据从(高度, 宽度)变为(批次大小, 通道数, 高度, 宽度),这时可以用unsqueeze添加批次和通道维度。同时,广播机制中,如果两个张量维度不匹配,可以用unsqueeze来调整,方便进行运算。例如,形状(3,)和(3,3)的张量相加,需要将前者变为(3,1)才能广播。 还要注意可能出现的错误,比如dim超出范围,或者多次使用unsqueeze导致维度过多。这部分需要提醒用户注意,避免常见错误。 最后,生成相关问题。根据用户提供的示例问题,应该围绕unsqueeze的相关函数(如squeeze)、常见错误、在模型中的应用等来提问。例如,如何移除维度,使用unsqueeze时可能出现的错误,以及与view函数的区别等。 在整理回答时,要确保所有数学表达式和代码示例符合格式要求,行内公式用$...$,独立公式用$$...$$,代码块用三个反引号包裹。同时,引用用户提供的参考文献,比如引用[3]中的示例,末尾添加引用标识[^3]。 现在需要检查是否有遗漏的信息,比如函数参数的具体说明,是否提到原地操作(in-place)的问题。根据PyTorch的知识,unsqueeze不会原地修改张量,而是返回一个新的张量,这点需要说明,避免用户误解。 总结一下,回答的结构应该是:功能简介、使用方法(两种调用方式、参数说明)、示例(不同dim的情况)、应用场景、注意事项、相关问题。确保每个部分清晰,示例明确,并且引用正确。</think>### PyTorch中`unsqueeze`函数使用说明 `unsqueeze`函数用于在张量的指定维度插入一个大小为1的新维度,常用于调整张量形状以满足运算或模型输入的要求。 --- #### 一、功能简介 1. **作用**:在指定位置插入维度大小为1的轴,扩展张量维度[^2][^3]。 2. **数学表达**:若输入张量形状为$(d_0, d_1, ..., d_{n-1})$,在维度$k$使用`unsqueeze`后,形状变为$(d_0, ..., 1, ..., d_{n-1})$,其中$1$插入在位置$k$。 --- #### 二、使用方法 ##### 1. 函数语法 - **方式1**:`torch.unsqueeze(input, dim)` - **方式2**:`tensor.unsqueeze(dim)`(张量对象直接调用)[^5] ##### 2. 参数说明 - `dim`:指定插入新维度的位置,支持正负索引: - 正数:从前往后计数(0-based) - 负数:从后往前计数(如-1表示最后一个维度之后) - **取值范围**:$[-input.dim()-1, input.dim()]$,超出会报错[^3]。 --- #### 三、示例代码 ##### 示例1:基础用法 ```python import torch x = torch.tensor([[1, 2], [3, 4]]) # 形状:(2, 2) y = x.unsqueeze(0) # 在dim=0插入维度 print(y.shape) # 输出:torch.Size([1, 2, 2]) ``` ##### 示例2:负索引 ```python z = x.unsqueeze(-1) # 在最后一个维度后插入 print(z.shape) # 输出:torch.Size([2, 2, 1]) ``` ##### 示例3:多维度扩展 ```python w = x.unsqueeze(1).unsqueeze(-1) # 依次在dim=1和dim=-1插入 print(w.shape) # 输出:torch.Size([2, 1, 2, 1]) ``` --- #### 四、应用场景 1. **数据预处理**:调整输入维度以满足模型要求。 ```python image = torch.randn(28, 28) # 单张图像形状 batch_image = image.unsqueeze(0) # 添加批次维度 → (1, 28, 28) ``` 2. **广播机制**:对齐张量维度以支持运算。 ```python a = torch.tensor([1, 2, 3]) # 形状:(3,) b = torch.randn(3, 3) c = a.unsqueeze(1) + b # 广播为(3,1)+(3,3)→(3,3) ``` 3. **模型输入**:适配卷积层输入(需4D张量:批次×通道×高×宽)[^4]。 --- #### 五、注意事项 1. **非原地操作**:`unsqueeze`返回新张量,原张量不变[^5]。 2. **维度限制**:避免重复插入过多维度(如多次调用可能导致形状混乱)。 3. **与`squeeze`配合**:可用`squeeze()`移除大小为1的维度。 ---
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

静静喜欢大白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值