你真的理解PyTorch中的dimension嘛?

10 篇文章 0 订阅
5 篇文章 0 订阅
文章通过示例解释了PyTorch中张量的维度概念,特别是如何在不同维度上进行求和操作。对于1D、2D和3D张量,分别展示了sum函数如何沿着指定维度折叠数据。文章指出,第一个维度通常对应最高级别的结构,例如在3D张量中表示批次。同时,对比了PyTorch的dim参数与Numpy的axis参数的相似性。
摘要由CSDN通过智能技术生成

Understanding dimensions in PyTorch

通过可视化3D张量上的求和过程,为PyTorch维度提供更好的直觉

当我们开始用PyTorch张量做一些基本的运算,比如求和时,对于一维张量来说,这看起来很容易,也很简单:

>> x = torch.tensor([1, 2, 3])
>> torch.sum(x)
tensor(6)

让我们从官方文件中的内容开始:

torch.sum(input, dim, keepdim=False, dtype=None) → Tensor

Returns the sum of each row of the input tensor in the given dimension dim.

当我们描述二维张量的形状时,我们说它包含一些行和一些列。因此,对于2x3张量,我们有2行3列:

>> x = torch.tensor([
     [1, 2, 3],
     [4, 5, 6]
   ])
>> x.shape
torch.Size([2, 3])

我们首先指定行(2行),然后指定列(3列)。我们可以得出一个结论,第一个维度(dim=0)用于行,第二个维度(dim=1)用于列。根据维度dim=0意味着行的推理,torch.sum(x,dim=0)会产生1x2张量(tensor[6,15]的结果为1+2+34+5+6)。但事实证明我们得到了不同的东西:一个1x3张量。

>> torch.sum(x, dim=0)
tensor([5, 7, 9])

当传递参数dim=1时,我们最终得到的结果是tensor[6,15]

>> torch.sum(x, dim=1)
tensor([6, 15])

在Numpy中的sum()方法中,我们需要传递的第二个参数是axis。Numpy中的sum()方法和PyTorch中的sum()方法几乎相同,除了PyTorch中的dim在Numpy中被称作axis之外

numpy.sum(a, axis=None, dtype=None, out=None, keepdims=False)

下面这句话是我们理解PyTorch中的dim和Numpy中的axis的关键:

Numpy sum()方法中的axis被用于折叠指定的axis,当axis=0时,它会折叠Numpy数据data的行,此时data只有一行的数据(也就是说其对data按列进行了求和)

然而,当我们引入第三维时,它就变得更棘手了。当我们观察3D张量的形状时,我们会注意到新的维度被预处理并占据第一个位置(下面用粗体显示),即dim=0代表第三个维度

>> y = torch.tensor([
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ]
   ])
>> y.shape
torch.Size([3, 2, 3])

这个三维张量的第一个维度(dim=0)是最高的,包含3个二维张量。因此,为了求和它,我们必须将它的3个元素折叠在一起:

>> torch.sum(y, dim=0)
tensor([[ 3,  6,  9],
        [12, 15, 18]])

在这里插入图片描述

对于第二个维度(dim=1),我们必须折叠行:

>> torch.sum(y, dim=1)
tensor([[5, 7, 9],
        [5, 7, 9],
        [5, 7, 9]])

在这里插入图片描述

最后,第三个维度折叠在列上:

>> torch.sum(y, dim=2)
tensor([[ 6, 15],
        [ 6, 15],
        [ 6, 15]])

在这里插入图片描述

参考

1、Understanding dimensions in PyTorch:https://towardsdatascience.com/understanding-dimensions-in-pytorch-6edf9972d3be
2、Numpy Sum Axis Intuition:https://medium.com/intuitionmath/numpy-sum-axis-intuition-6eb94926a5d1

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值