06-27 周四 理解torch.squeeze和unsqueeze

简介

最近在看动手学深度学习这本书的注意力机制章节,理解起来很吃力,主要原因是一些底层api的具体执行不理解。

代码

unsqueeze

在这里插入图片描述

x=torch.ones((2, 1, 4))
y = torch.ones((2,4,6))
z = torch.bmm(x, y)
print(f"z.shpae: ", z.shape)


weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape(2, 10)
print(f"weights: {weights}, values: {values}")
print(f"weights.shape: {weights.shape}, values.shape: {values.shape}")
t1 = weights.unsqueeze(1)
print(f"t1.shape: {t1.shape}, t1: {t1}")
print(f"values.dim: {values.dim()}")
t2 = values.unsqueeze(-1)
t3 = values.unsqueeze(2)
print(f"t2.shape: {t2.shape}, t2: {t2}")
print(f"t3.shape: {t3.shape}, t3: {t3}")
print(f"t2.dim: {t2.dim()}")

print(torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)))

输出

z.shpae:  torch.Size([2, 1, 6])
weights: tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000]]), values: tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])
weights.shape: torch.Size([2, 10]), values.shape: torch.Size([2, 10])
t1.shape: torch.Size([2, 1, 10]), t1: tensor([[[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000, 0.1000]],

        [[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000, 0.1000]]])
values.dim: 2
t2.shape: torch.Size([2, 10, 1]), t2: tensor([[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.],
         [ 8.],
         [ 9.]],

        [[10.],
         [11.],
         [12.],
         [13.],
         [14.],
         [15.],
         [16.],
         [17.],
         [18.],
         [19.]]])
t3.shape: torch.Size([2, 10, 1]), t3: tensor([[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.],
         [ 8.],
         [ 9.]],

        [[10.],
         [11.],
         [12.],
         [13.],
         [14.],
         [15.],
         [16.],
         [17.],
         [18.],
         [19.]]])
t2.dim: 3
tensor([[[ 4.5000]],

        [[14.5000]]])

解释
在这里插入图片描述

在 PyTorch 中,对张量进行unsqueeze操作时,dim的取值范围是(-input.dim() - 1, input.dim() + 1)(左闭右开)。
当dim为负数时,表示从后向前计数(即,dim == -1和dim == input.dim()等效)。负的dim值会被转化为dim + input.dim() + 1。

import torch

# 示例 1
a = torch.tensor([1, 2, 3, 4])  
print(a.shape)  # torch.Size([4]) 

b = torch.unsqueeze(a, 0)  
print(b.shape)  # torch.Size([1, 4])  # 在第 0 维(最前面)插入一维

c = torch.unsqueeze(a, 1)  
print(c.shape)  # torch.Size([4, 1])  # 在第 1 维(中间)插入一维

# 示例 2
d = torch.randn(2, 3)  
print(d.shape)  # torch.Size([2, 3]) 

e = torch.unsqueeze(d, 2)  
print(e.shape)  # torch.Size([2, 3, 1])  # 在第 2 维(最后面)插入一维

squeeze

在 PyTorch 中,torch.squeeze() 方法用于对张量的维度进行压缩,即去掉维数为 1 的维度。
其函数原型为:

torch.squeeze(input, dim=None, *, out=None) → Tensor

具体用法如下:
torch.squeeze(a):去掉张量 a 中所有维数为 1 的维度。
a.squeeze(N) 或 torch.squeeze(a, N):去掉张量 a 中指定的维数为 1 的维度 N。
需要注意的是,如果要删除的维度大小不是 1,则 squeeze() 方法不会删除该维度。

import torch

# 示例 1
a = torch.randn(1, 3, 1, 4)
print(a.shape)  # torch.Size([1, 3, 1, 4]) 

b = torch.squeeze(a)
print(b.shape)  # torch.Size([3, 4])  # 去掉所有维数为 1 的维度

# 示例 2
c = torch.randn(2, 1, 2, 1, 2)
print(c.shape)  # torch.Size([2, 1, 2, 1, 2]) 

d = torch.squeeze(c, 1) 
print(d.shape)  # torch.Size([2, 2, 1, 2])  # 去掉维度为 1 的第二维

总结

机器学习的库,欠缺的比较多,还需要不断的积累才行。太痛苦了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值