torch.(un)squeeze

torch.squeeze()以及torch.unsqueeze()函数分别是给目标tensor去掉维度只有1的那一维或者给目标的某一维度添加一维,对应的有两个in-place操作Tensor.squeeze_()Tensor.unsqueeze_()

一、torch.squeeze()

torch.squeeze(input, dim=None)

  • input ([Tensor] – the input tensor.
  • dim ([int], optional) – if given, the input will be squeezed only in this dimension

将输入的tensor中维度为1的那一维去除

Returns a tensor with all the dimensions of input of size 1 removed.

For example, if input is of shape: ( A × 1 × B × C × 1 × D A \times 1 \times B \times C \times 1 \times D A×1×B×C×1×D), then the out tensor will be of shape: ( A × B × C × D A \times B \times C \times D A×B×C×D).

When dim is given, a squeeze operation is done only in the given dimension. If input is of shape: ( A × 1 × B A \times 1 \times B A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape ( A × B A \times B A×B).

warning:If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.

x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
'''


y = torch.squeeze(x)
print("torch.squeeze(x):     ", y.size())
y = torch.squeeze(x, 0)
print("torch.squeeze(x, 0):  ", y.size())
y = torch.squeeze(x, 1)
print("torch.squeeze(x, 1):  ", y.size())
'''
torch.squeeze(x):      torch.Size([2, 2, 2])
torch.squeeze(x, 0):   torch.Size([2, 1, 2, 1, 2])
torch.squeeze(x, 1):   torch.Size([2, 2, 1, 2])
'''


''' in-place version '''
# modify x in origin storage
print(x.size())
x.squeeze_()
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
'''

二、torch.unsqueeze()

torch.unsqueeze(input, dim=None)

  • input ([Tensor] – the input tensor.
  • dim ([int], optional) – the index at which to insert the singleton dimension

将tensor的某一维增加1

Returns a new tensor with a dimension of size one inserted at the specified position.

The returned tensor shares the same underlying data with this tensor.

A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to [unsqueeze()] applied at dim = dim + input.dim() + 1.

维度:dim的取值范围为 [ − i n p u t . d i m ( ) − 1 ,    i n p u t . d i m ( ) + 1 ) [-input.dim()-1, ~~input.dim()+1) [input.dim()1,  input.dim()+1),负的维度会被映射到正的维度上,即if dim<0: dim = dim+input.dim()+1

举个例子,加入输入为二维:(3, 3)

dim的范围为[-3, 3),取值为-3, -2, -1, 0, 1, 2

dim = -3, 0时,输出:(1, 3, 3)

dim = -2, 1时,输出:(3, 1, 3)

dim = -1, 2时,输出:(3, 3, 1)

x = torch.tensor([1, 2, 3, 4])
print("x:                    ", x.size())
y1 = torch.unsqueeze(x, 0)
print("torch.squeeze(x, 0):  ", y1.size())
y2 = torch.unsqueeze(x, 1)
print("torch.squeeze(x, 1):  ", y2.size())
'''
x:                     torch.Size([4])
torch.squeeze(x, 0):   torch.Size([1, 4])
torch.squeeze(x, 1):   torch.Size([4, 1])
'''


print(x)
print(y1)
print(y2)
'''
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
'''


''' in-place version '''
# modify x in origin storage
print(x.size())
x.unsqueeze_(0)
print(x.size())
'''
torch.Size([3, 3])
torch.Size([1, 3, 3])
'''

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值