PyTorch学习之 torch.unsqueeze 函数

PyTorch学习之 torch.unsqueeze 函数

一、功能

torch.unsqueeze 是向张量的指定位置插入一个尺寸为1的新维度,从而扩展张量的形状

二、基本语法

torch.unsqueeze(data, dim)

三、参数说明

  • data (Tensor): 输入的张量。
  • dim (int): 指定插入新维度的位置。该值必须在 [-data.dim() - 1, data.dim() + 1) 范围内。
  • 如果dim<0,那么实际操作时,插入的维度:dim=dim + data.dim() + 1。其中data.dim()返回数据的维度

四、返回值

返回一个新的张量,在指定位置插入了一个尺寸为1的新维度。

⚠️⚠️⚠️返回的张量与输入张量共享内存,因此改变其中一个张量的值会影响另一个。

五、示例

实际使用过程中我们操作的对象可能本身就是张量类型(tensor)了,而tensor类型数据对象自带unsequeeze方法。

因此下面的实例中,我们都直接用数据对象调用其unsequeeze方法的形式展示给大家。

示例 1: 向0维插入新维度
import torch

# 创建一个形状为 (3, 4) 的张量
x = torch.randn(3, 4)
print("原始张量形状:", x.shape)

# 在0维插入新维度
y = x.unsqueeze(0)
# y = x.unsqueeze(dim=0) # 等价,但是不如上面好用
print("在0维插入新维度后的张量形状:", y.shape)

输出:

原始张量形状: torch.Size([3, 4])
在0维插入新维度后的张量形状: torch.Size([1, 3, 4])
示例 2: 向1维插入新维度
import torch

x = torch.randn(3, 4)
print("原始张量形状:", x.shape)

y = x.unsqueeze(1)
print("在1维插入新维度后的张量形状:", y.shape)

输出:

原始张量形状: torch.Size([3, 4])
在1维插入新维度后的张量形状: torch.Size([3, 1, 4])
示例 3: 向最后一维插入新维度
import torch

x = torch.randn(3, 4)
print("原始张量形状:", x.shape)

y = x.unsqueeze(-1)
print("在最后一维插入新维度后的张量形状:", y.shape)

输出:

原始张量形状: torch.Size([3, 4])
在最后一维插入新维度后的张量形状: torch.Size([3, 4, 1])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值