1. 定义
torch.squeeze函数的作用是对输入的张量进行处理,如果张量维度里面有大小为1 的部分,那我们就移除,否则保留
torch.squeeze(input, dim=None, *, out=None) → Tensor
- input : 输入的张量
- dim : 默认保留,可以指定维度
2. 代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: test_torch_squeeze
# @Create time: 2021/12/25 16:17
# 导入相关数据库
import torch
# 定义一个张量,维度为 size = [2,1,2,1,2]
# torch.sequeeze :表示移除张量中,所有大小为1的维度
# dim :指定此维度,如果大小为 1 ,则移除,否则保持
x = torch.zeros(2, 1, 2, 1, 2)
x_squeeze_0 = x.squeeze(0) # 第 dim=0 维度是不是1,则保持;size = [2,1,2,1,2]
x_squeeze_1 = x.squeeze(1) # 第 dim=1 维度是1,则移除; size = [2,2,1,2]
x_squeeze_all = torch.squeeze(x) # 移除所有size=1 的维度,size = [2,2,2]
print(f'x_shape={x.shape}')
print(f'x_squeeze_0={x_squeeze_0.shape}')
print(f'x_squeeze_1={x_squeeze_1.shape}')
print(f'x_squeeze_all={x_squeeze_all.shape}')
- 结果
x_shape=torch.Size([2, 1, 2, 1, 2])
x_squeeze_0=torch.Size([2, 1, 2, 1, 2])
x_squeeze_1=torch.Size([2, 2, 1, 2])
x_squeeze_all=torch.Size([2, 2, 2])