神经网络学习小记录33——pytorch中squeeze和unsqueeze函数的简单介绍
学习前言
经常看到在tf中看到squeeze,学会pytorch,结果刚入门就发现了这个函数,我决定弄懂它,顺便写篇文章水一下。
1、unsqueeze
其实unsqueeze的作用和np.expand_dims的作用非常类似,都是为矩阵增加一个维度,unsqueeze是为了pytorch中的tensor增加一个维度。
函数声明为:
torch.unsqueeze(dim)
其中dim表示需要在哪一维增加一个维度,dim必须被指定。
试验示例:
import torch
before_unsqueeze = torch.arange(12).reshape([3,4])
print(before_unsqueeze.data)
after_unsqueeze = before_unsqueeze.unsqueeze(1)
print(after_unsqueeze.data)
结果:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[[ 0, 1, 2, 3]],
[[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11]]])
2、squeeze
其实squeeze的作用和tf.squeeze的作用非常类似,二者都是将被操作目标中维度为1的部分去除。
函数声明为:
torch.squeeze(dim=None)
其中dim表示需要在哪一维去掉一个维度,如果不指定则自动寻找,如果指定则当指定的维度为1时去掉,如果不为1则不改变。
试验示例:
import torch
before_squeeze = torch.arange(12).reshape([1,3,4])
print(before_squeeze.data)
# 指定维度
after_squeeze = before_squeeze.squeeze(1)
print(after_squeeze.data)
# 自动去除
after_squeeze = before_squeeze.squeeze()
print(after_squeeze.data)
结果:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])