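torch.squeeze() and torch.unsqueeze() do opposite jobs. squeeze() removes the dimensions of size 1 from a tensor, and unsqueeze() inserts a new dimension of size 1 at a given position. Each has a corresponding in-place operation: Tensor.squeeze_() and Tensor.unsqueeze_().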
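The two functions undo each other. This minimal sketch inserts a size-1 dimension with unsqueeze() and then removes it with squeeze() at the same position, which returns the original shape. The shape (3, 4) is an arbitrary choice for illustration.

```python
import torch

x = torch.zeros(3, 4)          # arbitrary example shape
y = torch.unsqueeze(x, 1)      # insert a size-1 dim at position 1 -> (3, 1, 4)
z = torch.squeeze(y, 1)        # remove that size-1 dim again -> (3, 4)

print(y.shape)  # torch.Size([3, 1, 4])
print(z.shape)  # torch.Size([3, 4])
```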
1. torch.squeeze()
torch.squeeze(input, dim=None)
- input (Tensor) – the input tensor.
- dim (int, optional) – if given, the input will be squeezed only in this dimension
Removes the dimensions of size 1 from the input tensor.
Returns a tensor with all the dimensions of input of size 1 removed. For example, if input is of shape $(A \times 1 \times B \times C \times 1 \times D)$, then the output tensor will be of shape $(A \times B \times C \times D)$.
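Here is a concrete instance of the shape example above. The sizes A, B, C, D = 2, 3, 4, 5 are arbitrary values chosen only for illustration.

```python
import torch

# A=2, B=3, C=4, D=5 are arbitrary sizes chosen for illustration
x = torch.zeros(2, 1, 3, 4, 1, 5)
y = torch.squeeze(x)   # every size-1 dimension is dropped
print(y.shape)         # torch.Size([2, 3, 4, 5])
```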
When dim is given, a squeeze operation is done only in the given dimension. If input is of shape $(A \times 1 \times B)$, squeeze(input, 0) leaves the tensor unchanged (for $A \neq 1$), but squeeze(input, 1) will squeeze the tensor to the shape $(A \times B)$.
Warning: If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.
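This minimal sketch shows the pitfall for inputs of shape (batch, 1, features). The batch size of 1 and the feature size of 3 are made-up values. Passing dim explicitly removes only the intended dimension and keeps the batch dimension.

```python
import torch

# A hypothetical batch of 1 sample, each sample a 1 x 3 row: shape (batch=1, 1, 3)
batch = torch.zeros(1, 1, 3)

# squeeze() with no dim removes every size-1 dim, including the batch dim
print(torch.squeeze(batch).shape)     # torch.Size([3])

# Naming the dimension explicitly keeps the batch dimension intact
print(torch.squeeze(batch, 1).shape)  # torch.Size([1, 3])
```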
import torch

x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
'''
y = torch.squeeze(x)
print("torch.squeeze(x): ", y.size())
y = torch.squeeze(x, 0)
print("torch.squeeze(x, 0): ", y.size())
y = torch.squeeze(x, 1)
print("torch.squeeze(x, 1): ", y.size())
'''
torch.squeeze(x): torch.Size([2, 2, 2])
torch.squeeze(x, 0): torch.Size([2, 1, 2, 1, 2])
torch.squeeze(x, 1): torch.Size([2, 2, 1, 2])
'''
# In-place version: squeeze_() changes the shape of x itself instead of returning a new tensor
print(x.size())
x.squeeze_()
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
'''
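squeeze_() changes x itself. torch.squeeze() instead returns a new tensor object, but the PyTorch documentation states that the result shares storage with the input, so no data is copied. This small sketch (shapes chosen arbitrarily) writes through the squeezed result. x keeps its shape while its contents change.

```python
import torch

x = torch.zeros(2, 1, 2)
y = torch.squeeze(x)      # new tensor object with shape (2, 2), same storage as x
y[0, 0] = 5.0             # write through the squeezed view

print(x.shape)            # torch.Size([2, 1, 2])  (x's shape is untouched)
print(x[0, 0, 0].item())  # 5.0  (but its data changed)
print(x.data_ptr() == y.data_ptr())  # True
```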
2. torch.unsqueeze()
torch.unsqueeze(input, dim)
- input (Tensor) – the input tensor.
- dim (int) – the index at which to insert the singleton dimension
Inserts a new dimension of size 1 at the given position of the tensor.
Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with this tensor.
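Because the result shares data with the input, a write through the unsqueezed tensor is visible in the original. This is a minimal sketch with arbitrary values.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)   # shape (1, 4), a view on x's data
y[0, 2] = 30                # modify through the view

print(x.tolist())  # [1, 2, 30, 4]
```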
A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
Valid dims: dim takes values in [-input.dim() - 1, input.dim() + 1). A negative dim is mapped to a non-negative one, i.e. if dim < 0: dim = dim + input.dim() + 1.
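The mapping rule can be written out directly. normalize_unsqueeze_dim is a hypothetical helper written for this sketch, not a PyTorch function. The loop checks that each negative dim and its normalized form give the same result shape.

```python
import torch

def normalize_unsqueeze_dim(dim: int, ndim: int) -> int:
    """Hypothetical helper (not part of PyTorch): map a possibly negative
    unsqueeze dim to its non-negative equivalent via dim + ndim + 1."""
    if not -ndim - 1 <= dim < ndim + 1:
        raise IndexError(f"dim {dim} out of range [{-ndim - 1}, {ndim + 1})")
    return dim + ndim + 1 if dim < 0 else dim

x = torch.zeros(3, 3)
for d in range(-x.dim() - 1, x.dim() + 1):
    pos = normalize_unsqueeze_dim(d, x.dim())
    # the original dim and its normalized counterpart produce the same shape
    assert torch.unsqueeze(x, d).shape == torch.unsqueeze(x, pos).shape
    print(d, "->", pos)
```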
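For example, suppose the input is two-dimensional with shape (3, 3). Then dim ranges over [-3, 3), i.e. it can be -3, -2, -1, 0, 1, or 2:
- dim = -3 or 0: output shape (1, 3, 3)
- dim = -2 or 1: output shape (3, 1, 3)
- dim = -1 or 2: output shape (3, 3, 1)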
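The (3, 3) case can be checked directly. This small sketch prints the output shape for each valid dim and shows that an out-of-range dim (3 here) raises an IndexError.

```python
import torch

x = torch.zeros(3, 3)

# Output shape for every valid dim in [-3, 3); this mirrors the list above
shapes = {d: tuple(torch.unsqueeze(x, d).shape) for d in range(-3, 3)}
for d, s in shapes.items():
    print(d, s)

# dim = 3 lies outside [-3, 3), so PyTorch rejects it
raised = False
try:
    torch.unsqueeze(x, 3)
except IndexError as err:
    raised = True
    print("IndexError:", err)
```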
import torch

x = torch.tensor([1, 2, 3, 4])
print("x: ", x.size())
y1 = torch.unsqueeze(x, 0)
print("torch.unsqueeze(x, 0): ", y1.size())
y2 = torch.unsqueeze(x, 1)
print("torch.unsqueeze(x, 1): ", y2.size())
'''
x: torch.Size([4])
torch.unsqueeze(x, 0): torch.Size([1, 4])
torch.unsqueeze(x, 1): torch.Size([4, 1])
'''
print(x)
print(y1)
print(y2)
'''
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
'''
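A common reason to add a dimension is broadcasting. The outer-product use case here is an illustrative choice and does not come from the text above. A column of shape (4, 1) multiplied by a row of shape (1, 4) broadcasts to a (4, 4) result.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
col = x.unsqueeze(1)   # shape (4, 1), like y2 above
row = x.unsqueeze(0)   # shape (1, 4), like y1 above

# Broadcasting (4, 1) against (1, 4) gives (4, 4): outer[i, j] = x[i] * x[j]
outer = col * row
print(outer.shape)         # torch.Size([4, 4])
print(outer.tolist()[1])   # [2, 4, 6, 8]
```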
# In-place version: unsqueeze_() changes the shape of x itself instead of returning a new tensor
print(x.size())
x.unsqueeze_(0)
print(x.size())
'''
torch.Size([4])
torch.Size([1, 4])
'''
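unsqueeze_() and squeeze_() only rewrite a tensor's shape and stride metadata, so the data itself is not moved. This short sketch (tensor values are arbitrary) applies one after the other, then checks that the shape is restored and that the data pointer is unchanged.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
ptr = x.data_ptr()

x.unsqueeze_(0)   # shape (4,) -> (1, 4); only the metadata of x changes
x.squeeze_(0)     # shape (1, 4) -> (4,), undoing the previous step

print(x.shape)              # torch.Size([4])
print(x.data_ptr() == ptr)  # True: no copy was made
```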