1. squeeze()
函数形式如下:
torch.squeeze(input, dim=None) → Tensor
作用是对tensor变量进行降维,也就是维度压缩。但是有前提条件,就是被压缩的tensor中存在大小为1的维度,也就是他只能压缩冗余的维度,并保持数据内容不变。
如果不是指定参数中的dim,也就是维度,那么函数默认将压缩所有大小为1的维度,如:
import torch
a = torch.randn(2, 1, 2, 1, 2)
b = a.squeeze() # 也可以写成b = torch.squeeze(a) (官方文档写法)
print(b.shape) # 输出 torch.Size([2, 2, 2])
如果传入dim参数,那么squeeze()只对指定维度进行压缩(前提是大小为1)。
import torch
a = torch.randn(5, 1, 4)
print(a.squeeze(1).shape) # 输出 torch.Size([5, 4]),压缩了第1维
b = torch.randn(5, 4)
print(b.squeeze(1).shape) # 输出 torch.Size([5, 4]),没有压缩
2. unsqueeze()
函数形式如下:
torch.unsqueeze(input, dim) → Tensor
作用是对tensor变量,在参数dim指定的维度上进行扩充。
import torch
a = torch.randn(5, 4)
b = a.unsqueeze(0) # 也可以写成b = torch.unsqueeze(a, 0)(官方文档写法)
print(b.shape) # 输出torch.Size([1, 5, 4])
在一些模型中进行tensor之间的运算时,进行维度扩充很有用。当然在pytorch中存在广播机制(Broadcast),如果某个方法能够Broadcast,那么方法会将参数中的tensor自动进行维度变化以满足运算要求,例如torch.matmul().