[Pytorch学习笔记] 简单记录一下squeeze() 和 unsqueeze()

1. squeeze()

函数形式如下:

torch.squeeze(input, dim=None) → Tensor

作用是对tensor变量进行降维,也就是维度压缩。但是有前提条件,就是被压缩的tensor中存在大小为1的维度,也就是他只能压缩冗余的维度,并保持数据内容不变
如果不是指定参数中的dim,也就是维度,那么函数默认将压缩所有大小为1的维度,如:

import torch
a = torch.randn(2, 1, 2, 1, 2)
b = a.squeeze()   # 也可以写成b = torch.squeeze(a) (官方文档写法)
print(b.shape)   #  输出 torch.Size([2, 2, 2])

如果传入dim参数,那么squeeze()只对指定维度进行压缩(前提是大小为1)。

import torch
a = torch.randn(5, 1, 4)
print(a.squeeze(1).shape)  # 输出 torch.Size([5, 4]),压缩了第1维
b =  torch.randn(5, 4)
print(b.squeeze(1).shape)  # 输出 torch.Size([5, 4]),没有压缩

squeeze官方文档

2. unsqueeze()

函数形式如下:

torch.unsqueeze(input, dim) → Tensor

作用是对tensor变量,在参数dim指定的维度上进行扩充。

import torch
a = torch.randn(5, 4)
b = a.unsqueeze(0)  # 也可以写成b = torch.unsqueeze(a, 0)(官方文档写法)
print(b.shape) # 输出torch.Size([1, 5, 4])

unsqueeze官方文档

在一些模型中进行tensor之间的运算时,进行维度扩充很有用。当然在pytorch中存在广播机制(Broadcast),如果某个方法能够Broadcast,那么方法会将参数中的tensor自动进行维度变化以满足运算要求,例如torch.matmul().

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值