关于torch.max(a,dim)中维度的选取

Pytorch中,经常会使用torch.max(a,dim)对tensor进行处理,特别是针对多维的tensor,就感觉对dim的选取似懂非懂。

一、针对1维的数据

这个比较好理解,就是针对1维的数据取最大值,返回一个tensor类型的数值,和该数值对应的下标,合起来就是一个tuple类型。

import torch

a = torch.randn(3) #随机生成数组
max=torch.max(a,dim=0) #默认dim=0
print("a:\n", a)
print("************************************************")
print("max(a):", max) #输出最大值,以及对应的索引,tuple类型
print("max(a)_value:", max[0]) #只返回tensor数值
print("max(a)_index:", max[1]) #只返回对应的索引


<<
a:
 tensor([ 1.5691, -0.7801, -1.2262])
************************************************
max(a): (tensor(1.5691), tensor(0))
max(a)_value: tensor(1.5691)
max(a)_index: tensor(0)

二、针对2维数据

此时的tensor(行,列),可以理解为一张特征图,dim=0就是行间进行比较,dim=1就是列间进行比较

import torch

a = torch.randn(2,3) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是行
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0)  #dim=0,行与行之间进行比较,所以返回每一列的最大值
print("max(a)_1:", max_1)  #dim=1,列与列之间进行比较,所以返回每一行的最大值

<<
a:
 tensor([[ 0.1734, -0.7264,  0.6981],
        [ 0.0859,  1.2663, -0.0851]])
************************************************
max(a)_0: (tensor([ 0.1734,  1.2663,  0.6981]), tensor([ 0,  1,  0]))
max(a)_1: (tensor([ 0.6981,  1.2663]), tensor([ 2,  1]))

三、针对3维数据

此时的tensor(通道,行,列),可以理解为很多张特征图叠加在一起,dim=0就是通道间进行比较,dim=1就是行间进行比较,dim=2就是列间进行比较。

import torch

a = torch.randn(2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是通道
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是行
max_2=torch.max(a,dim=2) #针对第2个元素“4”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0)  #dim=0,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_1:", max_1)  #dim=1,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_2:", max_1)  #dim=1,列与列之间进行比较,所以返回每一张特征图,每一行的最大值

<<
a:
 tensor([[[ 0.5323,  1.5229, -0.6122,  0.6054],
         [ 1.2424, -1.6005,  0.0779,  0.9227],
         [-0.6340, -0.5770, -0.1672,  0.3598]],

        [[-0.3770, -0.4992,  1.8444, -1.1040],
         [ 1.2238,  0.7283, -1.6462,  0.0325],
         [-0.3555, -0.2599,  1.5741,  1.0683]]])
************************************************
max(a)_0: (tensor([[ 0.5323,  1.5229,  1.8444,  0.6054],
        [ 1.2424,  0.7283,  0.0779,  0.9227],
        [-0.3555, -0.2599,  1.5741,  1.0683]]), tensor([[ 0,  0,  1,  0],
        [ 0,  1,  0,  0],
        [ 1,  1,  1,  1]]))
max(a)_1: (tensor([[ 1.2424,  1.5229,  0.0779,  0.9227],
        [ 1.2238,  0.7283,  1.8444,  1.0683]]), tensor([[ 1,  0,  1,  1],
        [ 1,  1,  0,  2]]))
max(a)_2: (tensor([[ 1.2424,  1.5229,  0.0779,  0.9227],
        [ 1.2238,  0.7283,  1.8444,  1.0683]]), tensor([[ 1,  0,  1,  1],
        [ 1,  1,  0,  2]]))

四、针对4维数据

此时的tensor(batch_size,channel, 行,列),可以理解为一个批次的训练数据的集合,dim=0,是批次间的比较;dim=1,是每个批次,自己通道间的比较;dim=2对应的行比较;dim=3对应的是列比较

import torch

a = torch.randn(1,2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“1”,对应的是batch_size
max_1=torch.max(a,dim=1) #针对第2个元素“2”,对应的是通道
max_2=torch.max(a,dim=2) #针对第2个元素“3”,对应的是行
max_3=torch.max(a,dim=3) #针对第2个元素“4”,对应的是列

print("a:\n", a)
print("************************************************")

print("max(a)_0:", max_0)  #dim=0,多个输入之间进行比较,返回的是每次输入时,每一张特征图,同一像素位置上的最大值
                             #一般为1,一张张读的
print("max(a)_1:", max_1)  #dim=1,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_2:", max_2)  #dim=2,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_3:", max_3)  #dim=3,列与列之间进行比较,所以返回每一张特征图,每一行的最大值

<<

a:
 tensor([[[[ 0.6404,  0.5116, -0.5562,  2.2283],
          [-0.6507,  0.4440,  0.8723, -0.6538],
          [ 0.0352,  1.0738,  0.2382,  0.7763]],

         [[-0.5208,  0.4854, -0.0950,  1.3100],
          [ 0.0433, -0.6561,  0.1956, -0.3584],
          [-1.0031, -1.7104,  0.6768, -0.1648]]]])
************************************************
max(a)_0: (tensor([[[ 0.6404,  0.5116, -0.5562,  2.2283],
         [-0.6507,  0.4440,  0.8723, -0.6538],
         [ 0.0352,  1.0738,  0.2382,  0.7763]],

        [[-0.5208,  0.4854, -0.0950,  1.3100],
         [ 0.0433, -0.6561,  0.1956, -0.3584],
         [-1.0031, -1.7104,  0.6768, -0.1648]]]), tensor([[[ 0,  0,  0,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]],

        [[ 0,  0,  0,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]]]))
max(a)_1: (tensor([[[ 0.6404,  0.5116, -0.0950,  2.2283],
         [ 0.0433,  0.4440,  0.8723, -0.3584],
         [ 0.0352,  1.0738,  0.6768,  0.7763]]]), tensor([[[ 0,  0,  1,  0],
         [ 1,  0,  0,  1],
         [ 0,  0,  1,  0]]]))
max(a)_2: (tensor([[[ 0.6404,  1.0738,  0.8723,  2.2283],
         [ 0.0433,  0.4854,  0.6768,  1.3100]]]), tensor([[[ 0,  2,  1,  0],
         [ 1,  0,  2,  0]]]))
max(a)_3: (tensor([[[ 2.2283,  0.8723,  1.0738],
         [ 1.3100,  0.1956,  0.6768]]]), tensor([[[ 3,  2,  1],
         [ 3,  2,  2]]]))

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值