pytorch中squeeze()和unsqueeze()函数意义

1 unsqueeze()

函数功能:与squeeze()函数功能相反,用于添加维度;

>>> a = torch.Tensor(3)
>>> a
tensor([1.7718e+28, 1.0509e-38, 0.0000e+00])
>>> a.unsqueeze(0) # 扩展第一个维度
tensor([[1.7718e+28, 1.0509e-38, 0.0000e+00]])
>>> a.unsqueeze(1) #扩展第二个维度
tensor([[1.7718e+28],
        [1.0509e-38],
        [0.0000e+00]])

2 squeeze()

        squeeze本身有挤压的意思;

函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用;

        其中squeeze(0)代表若第一维度值为1则去除第一维度;

        squeeze(1)代表若第二维度值为1则去除第二维度;

>>> a = torch.Tensor(3,2)
>>> a
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.unsqueeze(0)
tensor([[[-6.5850e+34,  4.5759e-41],
         [-6.5850e+34,  4.5759e-41],
         [ 0.0000e+00,  0.0000e+00]]])
>>> a.squeeze(0) # 第一维度会被缩减
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.squeeze(1)# 第二维度不会被缩减
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.unsqueeze(-1) # 表示最后一个维度
tensor([[[-6.5850e+34],
         [ 4.5759e-41]],

        [[-6.5850e+34],
         [ 4.5759e-41]],

        [[ 0.0000e+00],
         [ 0.0000e+00]]])
>>> a.squeeze(-1)
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])

参考: 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍_Jaborie203的博客-CSDN博客_pytorch unsqueeze

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值