使用PyTorch 的 squeeze 方法将形状为 [1, 1, 5] 的张量压缩为形状为 [5] 的张量

文章介绍了PyTorch中的squeeze函数,用于删除张量中尺寸为1的维度。通过示例解释了如何使用squeeze()方法将形状为[1,1,5]的张量转换为[5],并讨论了当有多个维度为1以及指定特定维度压缩的情况。
摘要由CSDN通过智能技术生成

在这里插入图片描述


一、squeeze是什么?

在 PyTorch 中,squeeze() 是一个张量操作函数,用于删除张量中尺寸为 1 的维度,也就是将存在的维度中值为 1 的部分删掉。具体来说,squeeze() 函数将对张量进行操作并返回一个新的张量,新张量中删掉了所有尺寸为 1 的维度。如果张量某维度值不为 1,那么这个维度不会发生变化。如果指定了参数 dim,则只会在指定的轴向上进行操作。
我给举个简单的例子吧:

import torch

x = torch.randn(1,10,1,5)
print(x.shape) # 输出:torch.Size([1, 10, 1, 5])

y = torch.squeeze(x)
print(y.shape) # 输出:torch.Size([10, 5])

z = torch.squeeze(x,dim=2)
print(z.shape) # 输出:torch.Size([1, 10, 5])
在以上示例中,squeeze() 函数删除了第 1 和第 3 维度值为 1 的维度,新张量的形状为 [10, 5]。
在第三次使用 squeeze() 函数时,只删除了第三个维度值为 1 的维度,
新张量的形状为 [1, 10, 5],与原始张量形状相比并没有发生变化。

二、本文中的案例——使用 PyTorch 的 squeeze 方法将形状为 [1, 1, 5] 的张量压缩为形状为 [5] 的张量,具体代码如下:

import torch

# 创建形状为 [1, 1, 5] 的张量
x = torch.randn(1, 1, 5)
print("原始张量的形状:", x.shape)

# 使用 squeeze 方法将其压缩
x = torch.squeeze(x)
print("压缩后张量的形状:", x.shape)
原始张量的形状: torch.Size([1, 1, 5])
压缩后张量的形状: torch.Size([5])
我们可以注意到,在调用 squeeze 方法时,需要指定要压缩的维度,如果不指定,
则默认压缩所有长度为1的维度。在这个例子中,由于所有的维度长度都为1,
所以squeeze方法会将三个维度都压缩成一个维度,
形状为 [5] 的张量。如果原始张量的其他维度中存在长度不为1的维度,
则squeeze方法不会将其压缩,而是保留原始的张量形状。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值