pytorch 中 squeeze 和unsqueeze函数

1. torch.squeeze() 函数 :

作用:移除指定或所有维数为1的维度,从而得到维度减少的张量

解释一下:

x=torch.zeros(5,1,1,1)

print(x)

'输出'
tensor([[[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]]])

举个极端点的例子,这是一个4维的数组,除了第0个维度之外每个维度的维数均为1

也就是说,每一个0都被3个括号括着,这显然不太合理

下面调用squeeze函数:

y = x.squeeze()
print(y)
print(y.shape)

'输出'
tensor([0., 0., 0., 0., 0.])
torch.Size([5])

瞬间一系列的括号都没有了,是不是看着舒服了许多?

进一步:

y = x.squeeze(1)
print(y)
print(y.shape)

'输出'
tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
torch.Size([5, 1, 1])

这里添加了参数1,这样就只压缩了第1个维度(计数从0开始),一个0被2个括号括着

但压缩的前提是,该张量必须有维数为1的维度,比如:

y = x.squeeze(0)
print(y)
print(y.shape)

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = a.squeeze()
print(b)
print(b.shape)

'输出'
tensor([[[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]]])
torch.Size([5, 1, 1, 1])

tensor([[1, 1, 1],
        [2, 2, 2]])
torch.Size([2, 3])

y和b相对于x和a均没有发生变化,原因就是:x的第0个维度,维数不是1;a中更是没有维数为1的维度

另外:x.squeeze() 或者 torch.squeeze(x) 都不会让x发生改变

y = x.squeeze()
print(x.shape)
print(y.shape)


'输出'
torch.Size([5, 1, 1, 1])
torch.Size([5])

2. torch.unsqueeze() 函数 :

作用:在张量的制定维度插入新的维度得到维度提升的张量

举个例子:

 x= torch.zeros(5)
print(x)
print(x.shape)

'输出'
tensor([0, 0, 0, 0, 0])
torch.Size([5])

一维张量,总共5个0,接下来依次操作:

y = x.unsqueeze(dim=0)
print(y)
print(y.shape)

y = x.unsqueeze(dim=1)
print(y)
print(y.shape)

z = y.unsqueeze(dim=2)
print(z)
print(z.shape)

'输出'
tensor([[0., 0., 0., 0., 0.]])
torch.Size([1, 5])

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]])
torch.Size([5, 1])

tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
torch.Size([5, 1, 1])

把第0维进行扩张,就是在最外面加了一个括号

把第1维进行扩张,就是把里面的每个元素元素(也可以理解成是,扩充后的第1维,也就是0.)都加一个括号

继续套,选择dim=2,还是把最内层的了(也可以理解成是,扩充后的第2维,也就是0.),都加一个括号

再举个例子:

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(a)
print(a.shape)

b = a.unsqueeze(dim=0)
print(b)
print(b.shape)

b = a.unsqueeze(dim=1)
print(b)
print(b.shape)

b = a.unsqueeze(dim=2)
print(b)
print(b.shape)

'输出'
tensor([[1, 1, 1],
        [2, 2, 2]])
torch.Size([2, 3])

tensor([[[1, 1, 1],
         [2, 2, 2]]])
torch.Size([1, 2, 3])

tensor([[[1, 1, 1]],
        [[2, 2, 2]]])
torch.Size([2, 1, 3])

tensor([[[1],
         [1],
         [1]],
        [[2],
         [2],
         [2]]])
torch.Size([2, 3, 1])

怎么套的括号,是不是一目了然~

同样:x.unsqueeze() 或者 torch.unsqueeze(x) 都不会让x发生改变

x = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(x)
y = x.unsqueeze(dim=0)
print(y.shape)
print(x.shape)


'输出'
torch.Size([1, 2, 3])
torch.Size([2, 3])

  • 10
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dylan_zhang7788

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值