PyTorch维度操控:squeeze与unsqueeze详解

PyTorch中的squeezeunsqueeze是用于改变张量维度的重要函数:

squeeze()函数

作用:移除张量中大小为1的维度

import torch

# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
print(f"原始形状: {x.shape}")  # torch.Size([1, 3, 1, 4])

# 移除所有大小为1的维度
y = x.squeeze()
print(f"squeeze后: {y.shape}")  # torch.Size([3, 4])

# 指定移除特定维度
z = x.squeeze(0)  # 只移除第0维
print(f"squeeze(0)后: {z.shape}")  # torch.Size([3, 1, 4])

unsqueeze()函数

作用:在指定位置添加一个大小为1的维度

# 创建一个形状为 (3, 4) 的张量
a = torch.randn(3, 4)
print(f"原始形状: {a.shape}")  # torch.Size([3, 4])

# 在第0维添加维度
b = a.unsqueeze(0)
print(f"unsqueeze(0)后: {b.shape}")  # torch.Size([1, 3, 4])

# 在第1维添加维度
c = a.unsqueeze(1)
print(f"unsqueeze(1)后: {c.shape}")  # torch.Size([3, 1, 4])

# 在最后添加维度
d = a.unsqueeze(-1)
print(f"unsqueeze(-1)后: {d.shape}")  # torch.Size([3, 4, 1])

实际应用场景

1. 批处理维度处理

# 单个图像 (3, 224, 224) 需要添加batch维度
image = torch.randn(3, 224, 224)
batch_image = image.unsqueeze(0)  # (1, 3, 224, 224)

2. 广播运算

# 矩阵运算中调整维度以支持广播
x = torch.randn(3, 4)
y = torch.randn(4)

# y需要扩展维度才能与x相加
y_expanded = y.unsqueeze(0)  # (1, 4)
result = x + y_expanded  # 广播相加

3. 神经网络输出处理

# 模型输出后移除多余的维度
model_output = torch.randn(1, 10, 1)  # 批大小1,类别10,额外维度1
predictions = model_output.squeeze()  # (10,) 或 (1, 10)

这两个函数在深度学习中非常常用,特别是在处理不同维度的数据和确保张量形状匹配时。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值