PyTorch学习之 torch.unsqueeze 函数
一、功能
torch.unsqueeze
是向张量的指定位置插入一个尺寸为1的新维度,从而扩展张量的形状
。
二、基本语法
torch.unsqueeze(data, dim)
三、参数说明
data
(Tensor): 输入的张量。dim
(int): 指定插入新维度的位置。该值必须在[-data.dim() - 1, data.dim() + 1)
范围内。- 如果
dim<0
,那么实际操作时,插入的维度:dim=dim + data.dim() + 1
。其中data.dim()
返回数据的维度
四、返回值
返回一个新的张量,在指定位置插入了一个尺寸为1的新维度。
⚠️⚠️⚠️返回的张量与输入张量共享内存,因此改变其中一个张量的值会影响另一个。
五、示例
实际使用过程中我们操作的对象可能本身就是张量类型(tensor
)了,而tensor类型数据对象自带unsequeeze
方法。
因此下面的实例中,我们都直接用数据对象调用其unsequeeze方法的形式展示给大家。
示例 1: 向0维插入新维度
import torch
# 创建一个形状为 (3, 4) 的张量
x = torch.randn(3, 4)
print("原始张量形状:", x.shape)
# 在0维插入新维度
y = x.unsqueeze(0)
# y = x.unsqueeze(dim=0) # 等价,但是不如上面好用
print("在0维插入新维度后的张量形状:", y.shape)
输出:
原始张量形状: torch.Size([3, 4])
在0维插入新维度后的张量形状: torch.Size([1, 3, 4])
示例 2: 向1维插入新维度
import torch
x = torch.randn(3, 4)
print("原始张量形状:", x.shape)
y = x.unsqueeze(1)
print("在1维插入新维度后的张量形状:", y.shape)
输出:
原始张量形状: torch.Size([3, 4])
在1维插入新维度后的张量形状: torch.Size([3, 1, 4])
示例 3: 向最后一维插入新维度
import torch
x = torch.randn(3, 4)
print("原始张量形状:", x.shape)
y = x.unsqueeze(-1)
print("在最后一维插入新维度后的张量形状:", y.shape)
输出:
原始张量形状: torch.Size([3, 4])
在最后一维插入新维度后的张量形状: torch.Size([3, 4, 1])