神经网络学习小记录33——pytorch中squeeze()和unsqueeze()函数的简单介绍

神经网络学习小记录33——pytorch中squeeze和unsqueeze函数的简单介绍

学习前言

经常看到在tf中看到squeeze,学会pytorch,结果刚入门就发现了这个函数,我决定弄懂它,顺便写篇文章水一下。
在这里插入图片描述

1、unsqueeze

其实unsqueeze的作用和np.expand_dims的作用非常类似,都是为矩阵增加一个维度,unsqueeze是为了pytorch中的tensor增加一个维度。

函数声明为:

torch.unsqueeze(dim)

其中dim表示需要在哪一维增加一个维度,dim必须被指定。

试验示例:

import torch
before_unsqueeze = torch.arange(12).reshape([3,4])
print(before_unsqueeze.data)
after_unsqueeze = before_unsqueeze.unsqueeze(1)
print(after_unsqueeze.data)

结果:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

2、squeeze

其实squeeze的作用和tf.squeeze的作用非常类似,二者都是将被操作目标中维度为1的部分去除。

函数声明为:

torch.squeeze(dim=None)

其中dim表示需要在哪一维去掉一个维度,如果不指定则自动寻找,如果指定则当指定的维度为1时去掉,如果不为1则不改变。

试验示例:

import torch
before_squeeze = torch.arange(12).reshape([1,3,4])
print(before_squeeze.data)
# 指定维度
after_squeeze = before_squeeze.squeeze(1)
print(after_squeeze.data)
# 自动去除
after_squeeze = before_squeeze.squeeze()
print(after_squeeze.data)

结果:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值