PyTorch中维度的概念(转)

转载自:

https://mathpretty.com/12065.html

 

简介

今天在使用torch中的topk的时候, 对于dim产生了一些疑问. 后面也是找到了规律, 但是还是很困惑他为什么是这么设计的, 即dim与tensor本身的行列是不一致的. 然后就查了一下, 真的找到了一篇很好的文章, 解决了我的困惑, 就想在这里记录一下.

我这一篇文章里的所有的动图, 都是来自与下面这篇文章, 写的非常直观.

原文链接(十分棒的文章), Understanding dimensions in PyTorch

关于这里设计的代码, 有一个完整的notebook的文档, 具体链接见GithubPytorch维度介绍.ipynb

理解PyTorch维度概念

首先我们从最基础的开始, 当我们在Pytorch中定义一个二维的tensor的时候, 他包含行和列. 例如下面我们创建一个2*3的tensor.

  1. x = torch.tensor([
  2.         [1,2,3],
  3.         [4,5,6]
  4.     ])
  5. # 我们可以看到"行"是dim=0, "列"是dim=1
  6. print(x.shape)
  7. >> torch.Size([2, 3])

我们可以看到打印的结果显示:

  • first dimension (dim=0) stays for rows, 第一个维度代表行, 因为是2, 实际x就是2行
  • the second one (dim=1) for columns, 第二个维度代表列, 因为是3

于是, 我们会认为, torch.sum(x, dim=0)就是(1+2+3, 4+5+6)=tensor([6, 15]), 但是实际情况却不是这个样子的.

  1. torch.sum(x, dim=0)
  2. >> tensor([5, 7, 9])

我们可以看到按照dim=0求和, 其实是在按列相加, 也就是(1+4, 2+5, 3+6) =tensor([5, 7, 9]), 和我们想象的完全不一样. 我们再看一下按照dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >> tensor([ 6, 15])

可以看到, 在按照dim=1的时候求和的时候, 其实在按照按行进行求和,  (1+2+3, 4+5+6)=tensor([6, 15]), 这就让人很困惑, 明明上面说的是dim=0代表是行.

于是, 原文作者在一篇介绍numpy维度的文章中, 找到了问题的关键所在. 也就是下面的这段话(numpy中的axis也就是这里的dim).

The way to understand the "axis" of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).

上面的话简单翻译就是, 当按照axis=0进行求和的时候, 其实可以想象为对axis=0这个维度进行挤压, 最后只剩下一行, 那一行就是结果, 也就是按列在相加.

是不是还是会有一些困惑, 我们还是对于上面的例子(tensor([[1,2,3], [4,5,6]])), 看一下在dim=0的时候, 为什么是列相加, 以及上面的collapse the specific axis(dim)的含义.

理解PyTorch中维度的概念

如上面的动图所示, 当dim=0的时候, 按每一行的元素进行相加, 最后的结果就是和按列求和.

对于三维向量

下面我们更进一步, 来看一下对于三维的tensor, 在各个维度进行sum操作的结果. 首先我们看一下每一个dim代表的含义.

  1. # 看一下三维的
  2. x = torch.tensor([
  3.         [
  4.          [1,2,3],
  5.          [4,5,6]
  6.         ],
  7.         [
  8.          [1,2,3],
  9.          [4,5,6]
  10.         ],
  11.         [
  12.          [1,2,3],
  13.          [4,5,6]
  14.         ]
  15.     ])
  16. # 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
  17. print(x.shape)
  18. >> torch.Size([3, 2, 3])

可以看到此时dim=0是第三个维度, dim=1是行, dim=2是列.

  1. torch.sum(x, dim=0)
  2. >>
  3. tensor([[ 3,  6,  9],
  4.         [12, 15, 18]])

我们可以将其看成是各个二维平面对应元素求和, 还是有点绕, 还是直接看下面的动图.

理解PyTorch中维度的概念

接着是对dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >>
  3. tensor([[5, 7, 9],
  4.         [5, 7, 9],
  5.         [5, 7, 9]])

还是直接看下面的动图, 来进行理解.

理解PyTorch中维度的概念

最后按照dim=2来进行求和.

  1. torch.sum(x, dim=2)
  2. >>
  3. tensor([[ 6, 15],
  4.         [ 6, 15],
  5.         [ 6, 15]])

还是使用动图来进行解释.

理解PyTorch中维度的概念

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值