pytorch中squeeze()和unsqueeze()函数介绍

⭐能让人成长的,从来不是停留在舒适区,再迈一步,再试一次,你总能发现一个更强大的自己。

文章目录

先来看看官方的Docs,链接在这里:torch.squeeze — PyTorch master documentation

 这个函数就是返回一个张量,将input中大小为1的维度都删除。例如:假设一个输入的shape为(AX1XBXCX1XD),则其output的shape为:(AXBXCXD)。若给定维度,则删除给定的维度,但是只有大小为1的维度才会被删除。下面举一个例子:

import torch
x = torch.zeros(1,2,1,2,3)
print(f"x:{x.shape}")
y = torch.squeeze(x) # x中大小为1的维度都删除
print(f"y:{y.shape}")
z = torch.squeeze(x,0) # 删除第一个维度
print(f"z:{z.shape}")
w = torch.squeeze(x,-3) # -1是指倒数第一个维度,-2是指倒数第二个,依次类推
print(f"w:{w.shape}")
g = torch.squeeze(x,1) # 只有大小为1的维度才可以被删除
print(f"g:{g.shape}")

结果如下:

x:torch.Size([1, 2, 1, 2, 3])
y:torch.Size([2, 2, 3])
z:torch.Size([2, 1, 2, 3])
w:torch.Size([1, 2, 2, 3])
g:torch.Size([1, 2, 1, 2, 3])

二、unsqueeze()

官方Docs如下,链接在这里:torch.unsqueeze — PyTorch master documentation

unsqueeze()就是给指定位置加上维数为一的维度。返回的张量与该张量共享相同的基础数据。例子如下:

import torch
x = torch.zeros(1,2,1,2,3)
print(f"x:{x.shape}")
y = torch.unsqueeze(x,1) # 在第二维增加一个维度,该维度的大小为1
print(f"y:{y.shape}")

结果如下:

x:torch.Size([1, 2, 1, 2, 3])
y:torch.Size([1, 1, 2, 1, 2, 3])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值