对PyTorch的dim的理解


前言

  做深度学习的项目离不开对tensor的操作,tensor中文名称是张量,以PyTorch框架为例,张量是PyTorch的基本数据类型,初学者对张量操作时,常常会被dim这个参数困扰,本文测试了torch.max()、torch.argmax()、torch.softmax()、torch.stack()四个函数的dim值,以利于初学者对dim和张量的理解。
  本文基于PyTorch框架测试,在PyTorch中,从数据维度角度看,可分为:
  scalar(标量):一个数值;
  vector(向量):一维数组;
  matrix(矩阵):二维数组;
  tensor(张量):大于二维的数组。
  但是有时候我们也叫标量为零维张量、向量为一维张量等等,这样的叫法其实区分并不明显。

多维张量的维度

  使用如下代码生成3 * 2 * 2的三维张量:

import torch
x=torch.arange(12).reshape((3,2,2))
print(x)
for i in range(3):
    for j in range(2):
        for k in range(2):
            print("x[{}{}{}] is {}".format(i,j,k,x[i][j][k]))

  结果如下:

tensor([[[ 0,  1],
         [ 2,  3]],
        [[ 4,  5],
         [ 6,  7]],
        [[ 8,  9],
         [10, 11]]])
x[000] is 0
x[001] is 1
x[010] is 2
x[011] is 3
x[100] is 4
x[101] is 5
x[110] is 6
x[111] is 7
x[200] is 8
x[201] is 9
x[210] is 10
x[211] is 11

  首先观察生成的矩阵,我们会发现,三维张量最前面会有三个“[[[”,同理可以实验验证,几维张量就会有几个“[”。
  观察x[0][0][0]、x[1][0][0]、x[2][0][0],值分别为0、4、8,会发现当第0个维度(人们日常习惯从1开始计数,计算机习惯从0开始计数)改变时,跨越的是两个“[”,也就是第0维控制的是最外层;
  观察x[0][0][0]、x[0][1][0],值分别为0、2,此时第一个维度改变了,会发现在第0维的控制范围内,改变的值跨越一个‘[’,也就是说第一维控制的是从外向里数的第二层,再观察x[1][0][0]、x[1][1][0],可以验证我们的想法;
  观察x[0][0][0]、x[0][0][1],值分别为0、1,此时第二个维度改变了,也就是最后一个维度改变了,发现他改变的值没有跨越“[”。
  接着使用如下代码生成一个四维张量:

import torch
x=torch.arange(24).reshape((3,2,2,2))
print(x)
for i in range(3):
    for j in range(2):
        for k in range(2):
            for z in range(2):
                print("x[{}{}{}{}] is {}".format(i,j,k,z,x[i][j][k][z]))

  结果如下:

tensor([[[[ 0,  1],
          [ 2,  3]],
         [[ 4,  5],
          [ 6,  7]]],
        [[[ 8,  9],
          [10, 11]],
         [[12, 13],
          [14, 15]]],
        [[[16, 17],
          [18, 19]],
         [[20, 21],
          [22, 23]]]])
x[0000] is 0
x[0001] is 1
x[0010] is 2
x[0011] is 3
x[0100] is 4
x[0101] is 5
x[0110] is 6
x[0111] is 7
x[1000] is 8
x[1001] is 9
x[1010] is 10
x[1011] is 11
x[1100] is 12
x[1101] is 13
x[1110] is 14
x[1111] is 15
x[2000] is 16
x[2001] is 17
x[2010] is 18
x[2011] is 19
x[2100] is 20
x[2101] is 21
x[2110] is 22
x[2111] is 23

  观察x[0][0][0][0]、x[1][0][0][0、x[2][0][0][0],值分别为0、8、16,会发现当第0个维度改变时,跨越的是三个“[”,也就是第0维控制的是最外层;
  观察x[0][0][0][0]、x[0][1][0][0],值分别为0、4,此时第一个维度改变了,会发现在第0维的控制范围内,改变的值跨越两个‘[’,也就是说第一维控制的是从外向里数的第二层,再观察x[1][0][0][0]、x[1][1][0][0],可以验证我们的想法;
  观察x[0][0][0][0]、x[0][0][1][0],值分别为0、2,此时第二个维度改变了,发现它改变的值跨越一个‘[’,也就是说第一维控制的是从外向里数的第三层;
  观察x[0][0][0][0]、x[0][0][0][1],值分别为0、1,此时第三个维度改变了,也就是最后一个维度改变了,但是它改变的值没有跨越“[”。

  通过上述两个例子,可以发现,在n维张量中,如果改变第n-1维的值,则改变的值跨越0个’[’;改变第n-2维,跨越的维度是1个’[’;以此类推,改变第0维,改变的值跨越n-1个’[’。
  文笔不太好,希望通过上述例子,帮助大家认识张量维度。

torch.max()

  使用如下代码测试torch.max()函数,dim可取的值为0、1、2、3、-1(dim=3和dim=-1效果是一样的,即使用函数操作n维张量,函数参数中dim=n-1和dim=-1效果是一样的),测试当dim=0时的代码和运行结果如下:

import torch
x=torch.randint(1,100,(3,2,2))
print(x)
a,b= torch.max(x, dim=0)
print(a)
print(b)

  结果如下:

tensor([[[59,  1],
         [ 5, 75]],
        [[93, 12],
         [20, 99]],
        [[67, 31],
         [32, 34]]])
tensor([[93, 31],
        [32, 99]])
tensor([[1, 2],
        [2, 1]])

  torch.max()主要找出某个维度的最大值,当dim=0时,表示分别找出第0维对应位置的最大值,即比较x[0][0][0]、x[1][0][0]、x[2][0][0];x[0][0][1]、x[1][0][1]、x[2][0][1];x[0][1][0]、x[1][1][0]、x[2][1][0];x[0][1][1]、x[1][1][1]、x[2][1][1]。然后得到4个值,a表示每个位置上返回的最大值为多少,b表示该最大值在第0维上的位置。
  测试当dim=1时的代码如下:

import torch
x=torch.tensor([[[59,  1],
         [ 5, 75]],
        [[93, 12],
         [20, 99]],
        [[67, 31],
         [32, 34]]])
print(x)
a,b= torch.max(x, dim=1)
print(a)
print(b)

  结果如下:

tensor([[[59,  1],
         [ 5, 75]],
        [[93, 12],
         [20, 99]],
        [[67, 31],
         [32, 34]]])
tensor([[59, 75],
        [93, 99],
        [67, 34]])
tensor([[0, 1],
        [0, 1],
        [0, 1]])

  dim=1时,表示分别找出第1维对应位置的最大值,即比较x[0][0][0]、x[0][1][0];x[0][0][1]、x[0][1][1];等,最后得到6个值。
dim=2,dim=3原理一样。

torch.argmax()

  torch.argmax()也是求某个维度对应位置的最大值,不过torch.argmax()相比于torch.max()只返回最大值对应的位置,不返回最大值是多少,示例代码和运行结果如下:

import torch
x=torch.tensor([[[59,  1],
         [ 5, 75]],
        [[93, 12],
         [20, 99]],
        [[67, 31],
         [32, 34]]])
print(x)
a,b= torch.argmax(x, dim=0)
print(a)
print(b)

tensor([[[59,  1],
         [ 5, 75]],
        [[93, 12],
         [20, 99]],
        [[67, 31],
         [32, 34]]])
tensor([1, 2])
tensor([2, 1])

torch.softmax()

  torch.softmax函数用于将多个输出值转换为多个概率值,范围在[0,1],且概率相加和为1。
  若对一维张量进行softmax转换,就是对一维张量的几个数做softmax处理,不涉及到维度,把dim设置为0或-1即可,代码和运行结果如下:

import torch
x=torch.tensor([1,2,3]).float()
print(x)
y=torch.softmax(x,dim=0)
print(y)

tensor([1., 2., 3.])
tensor([0.0900, 0.2447, 0.6652])

  在代码中需要注意的是:torch.tensor初始化的值默认为int类型,即int64,但是softmax函数没有针对int64类型数据的代码实现,所以应该把softmax要处理的数据类型改为浮点型。

  若对二维及以上的张量进行softmax转换,则dim的值和之前一样,从第0维开始,到最后一维,就越往张量的深处比较。

  举例说明:对三维张量dim=0的softmax转换的代码和运行结果如下:

import torch
x=torch.arange(6).reshape((3,2,1)).float()
print(x)
y=torch.softmax(x,dim=0)
print(y)
tensor([[[0.],
         [1.]],
        [[2.],
         [3.]],
        [[4.],
         [5.]]])
tensor([[[0.0159],
         [0.0159]],
        [[0.1173],
         [0.1173]],
        [[0.8668],
         [0.8668]]])

  对三维张量dim=2的softmax转换的代码和运行结果如下:

import torch
x=torch.arange(6).reshape((3,2,1)).float()
print(x)
y=torch.softmax(x,dim=1)
print(y)

tensor([[[0.],
         [1.]],
        [[2.],
         [3.]],
        [[4.],
         [5.]]])
tensor([[[0.2689],
         [0.7311]],
        [[0.2689],
         [0.7311]],
        [[0.2689],
         [0.7311]]])

torch.stack()

  torch.stack()将列表的多个元素融合,dim和前几个一样,关于其函数功能,直接看例子即可:
  dim=0时,代码和结果如下:

import torch
x=torch.arange(8).reshape((2,2,2))
y=torch.arange(8,16).reshape((2,2,2))
z=torch.stack([x,y],dim=0)
print(x)
print(y)
print(z)

tensor([[[0, 1],
         [2, 3]],
        [[4, 5],
         [6, 7]]])
tensor([[[ 8,  9],
         [10, 11]],
        [[12, 13],
         [14, 15]]])
tensor([[[[ 0,  1],
          [ 2,  3]],
         [[ 4,  5],
          [ 6,  7]]],
        [[[ 8,  9],
          [10, 11]],
         [[12, 13],
          [14, 15]]]])

  从另一个角度观察,可以发现,dim=0时,就是两个张量的简单叠加。

  dim=1时,代码和结果如下:

import torch
x=torch.arange(8).reshape((2,2,2))
y=torch.arange(8,16).reshape((2,2,2))
z=torch.stack([x,y],dim=1)
print(x)
print(y)
print(z)

tensor([[[0, 1],
         [2, 3]],
        [[4, 5],
         [6, 7]]])
tensor([[[ 8,  9],
         [10, 11]],
        [[12, 13],
         [14, 15]]])
tensor([[[[ 0,  1],
          [ 2,  3]],
         [[ 8,  9],
          [10, 11]]],
        [[[ 4,  5],
          [ 6,  7]],
         [[12, 13],
          [14, 15]]]])

  可以发现,dim=1时,从第1维度上进行了拼接,组成了一个新的张量。

总结

  1、首先要理解PyTorch张量的维度,然后理解这些函数在不同维度上的操作就变得十分容易了;
  2、为了避免和PyTorch叙述的dim不同,本文是从第0维开始叙述的,即是日常所说的第一维。

  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值