1.squeeze(dim)
其中 dim 从 0 开始计数。
函数效果为将参数对应的维去掉,需满足该维值为1,
举个例子就是 view((1, 2, 3)) 中,只有 1 那个维才可以被 squeeze() 降维。
初始化一个三维的 Tensor :
# coding=utf-8
import torch
x = torch.arange(6).view((1, 2, 3)) # torch.Size([1, 2, 3])
使用 squeeze(0) 或 squeeze(-3) 对其降维:
temp1 = x.squeeze(0) # torch.Size([2, 3])
temp2 = x.view((2, 3)) # torch.Size([2, 3])
print('temp1: {}\n temp2: {}'.format(temp1, temp2))
输出:
temp1: tensor([[0, 1, 2],
[3, 4, 5]])
temp2: tensor([[0, 1, 2],
[3, 4, 5]])
这里使用 view() 与其进行对比,其输出与 squeeze() 一致。
试一试对 x 使用 squeeze(1):
temp3 = x.squeeze(1) # torch.Size([1, 2, 3])
print('x: {}\n temp3: {}'.format(x, temp3))
输出:
x: tensor([[[0, 1, 2],
[3, 4, 5]]])
temp3: tensor([[[0, 1, 2],
[3, 4, 5]]])
发现程序没有报错,且 temp3 的值与 x 的值一样,并没有降维的效果。
2.unsqueeze(dim)
其中 dim 从 0 开始计数。
与 squeeze() 对应的,unsqueeze() 则是增加维度。
初始化一个二维的 Tensor:
# coding=utf-8
import torch
x = torch.arange(6).view((2, 3)) # torch.Size([2, 3])
对其加入一个维度:
temp1 = x.unsqueeze(1) # torch.Size([2, 1, 3])
temp2 = x.unsqueeze(-2) # torch.Size([2, 1, 3])
temp3 = x.view((2, 1, 3)) # torch.Size([2, 1, 3])
print('temp1: {}\ntemp2: {}\n temp3: {}'.format(temp1, temp2, temp3))
输出:
temp1: tensor([[[0, 1, 2]],
[[3, 4, 5]]])
temp2: tensor([[[0, 1, 2]],
[[3, 4, 5]]])
temp3: tensor([[[0, 1, 2]],
[[3, 4, 5]]])
可以看到与使用 view() 函数效果一致。