文章目录
一、squeeze与unsqueeze函数
顾名思义,squeeze
函数的作用是压缩(降维);unsqueeze
函数的作用是解压(升维)。
1.squeeze函数
torch.squeeze(input, dim=None, *, out=None)
功能:
- 返回一个所有维度不为1的张量。
比如,输入张量维度为A×1×B×C×1×D
,则经squeeze
函数的输出张量维度为A×B×C×D
。 - 当
dim
参数设置时,只在所设置的dim
维度进行张量的压缩,如果所设置的dim
维度的尺寸不为1,则张量不会发生任何改变。
比如,输入张量的维度为A×1×B
,squeeze(input, 0)
函数不会使得输出张量的维度有任何改变。但是squeeze(input, 1)
函数会使得输出张量的维度变为A×B
。
需要注意的是:
- 返回的张量与输入张量共享内存,因此,其中一个张量的内容改变,另一个张量的内容也会改变
- 如果一个张量batch维度设置为1,那么经squeeze函数将会移除batch维度,这会导致错误
参数:
input(Tensor)
:输入张量dim(int, optional)
:如果这个参数被设置,则输入张量仅在这个维度(dim)上进行压缩
例子:
x = torch.zeros(2, 1, 2, 1, 2)
x.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
y.size()
=>torch.Size([2, 2, 2])
y = torch.squeeze(x, 0)
y.size()
=>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1)
y.size()
=>torch.Size([2, 2, 1, 2])
2.unsqueeze函数
torch.unsqueeze(input, dim)
功能
- 返回一个在指定位置插入尺寸为 1 的新张量
dim
参数的区间为[-input.dim() - 1, input.dim() + 1)
,负的dim
维度相当于dim = dim + input.dim() + 1
参数:
input (Tensor)
:输入张量dim (int)
:插入的索引
案例:
# shape=4
x = torch.tensor([1, 2, 3, 4])
# shape=(1,4)
torch.unsqueeze(x, 0)
=>tensor([[ 1, 2, 3, 4]])
# shape=(4,1)
torch.unsqueeze(x, 1)
=>t