前言
根据我的理解及有关资料记录一下pytorch中的squeeze()和unsqueeze()用法,方便以后查阅~
一、torch.unsqueeze()
torch.unsqueeze()用于数据扩维,在指定维度上添加一维,添加的维度大小为1,torch.unsqueeze()接收一个参数,用于指定在哪个维度上扩充。示例代码如下:
import torch
A = torch.randn(1,7)
B = A.unsqueeze(0)
C = A.unsqueeze(1)
D = A.unsqueeze(2)
print(B.shape)
print(C.shape)
print(D.shape)
输出如下:
torch.Size([1, 1, 7])
torch.Size([1, 1, 7])
torch.Size([1, 7, 1])
二、torch.squeeze()
与torch.unsqueeze()作用正好相反,torch.squeeze()用于对数据维度进行压缩。该函数接收一个参数,指定需要压缩的维度(只能压缩大小为1的维度)。若不指定参数,那么会压缩掉所有大小为1的维度。示例代码如下:
import torch
A = torch.randn(1,7)
D = A.unsqueeze(2).unsqueeze(3).unsqueeze(4) # 扩展第3维、第4维、第5维
E = D.squeeze(2) # 压缩第3维
F = D.squeeze() # 压缩所有维度为1的维
print(D.shape)
print(E.shape)
print(F.shape)
输出如下:
torch.Size([1, 7, 1, 1, 1])
torch.Size([1, 7, 1, 1])
torch.Size([7])