PyTorch中的squeeze
和unsqueeze
是用于改变张量维度的重要函数:
squeeze()函数
作用:移除张量中大小为1的维度
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
print(f"原始形状: {x.shape}") # torch.Size([1, 3, 1, 4])
# 移除所有大小为1的维度
y = x.squeeze()
print(f"squeeze后: {y.shape}") # torch.Size([3, 4])
# 指定移除特定维度
z = x.squeeze(0) # 只移除第0维
print(f"squeeze(0)后: {z.shape}") # torch.Size([3, 1, 4])
unsqueeze()函数
作用:在指定位置添加一个大小为1的维度
# 创建一个形状为 (3, 4) 的张量
a = torch.randn(3, 4)
print(f"原始形状: {a.shape}") # torch.Size([3, 4])
# 在第0维添加维度
b = a.unsqueeze(0)
print(f"unsqueeze(0)后: {b.shape}") # torch.Size([1, 3, 4])
# 在第1维添加维度
c = a.unsqueeze(1)
print(f"unsqueeze(1)后: {c.shape}") # torch.Size([3, 1, 4])
# 在最后添加维度
d = a.unsqueeze(-1)
print(f"unsqueeze(-1)后: {d.shape}") # torch.Size([3, 4, 1])
实际应用场景
1. 批处理维度处理
# 单个图像 (3, 224, 224) 需要添加batch维度
image = torch.randn(3, 224, 224)
batch_image = image.unsqueeze(0) # (1, 3, 224, 224)
2. 广播运算
# 矩阵运算中调整维度以支持广播
x = torch.randn(3, 4)
y = torch.randn(4)
# y需要扩展维度才能与x相加
y_expanded = y.unsqueeze(0) # (1, 4)
result = x + y_expanded # 广播相加
3. 神经网络输出处理
# 模型输出后移除多余的维度
model_output = torch.randn(1, 10, 1) # 批大小1,类别10,额外维度1
predictions = model_output.squeeze() # (10,) 或 (1, 10)
这两个函数在深度学习中非常常用,特别是在处理不同维度的数据和确保张量形状匹配时。