Pytorch中,经常会使用torch.max(a,dim)对tensor进行处理,特别是针对多维的tensor,就感觉对dim的选取似懂非懂。
一、针对1维的数据
这个比较好理解,就是针对1维的数据取最大值,返回一个tensor类型的数值,和该数值对应的下标,合起来就是一个tuple类型。
import torch
a = torch.randn(3) #随机生成数组
max=torch.max(a,dim=0) #默认dim=0
print("a:\n", a)
print("************************************************")
print("max(a):", max) #输出最大值,以及对应的索引,tuple类型
print("max(a)_value:", max[0]) #只返回tensor数值
print("max(a)_index:", max[1]) #只返回对应的索引
<<
a:
tensor([ 1.5691, -0.7801, -1.2262])
************************************************
max(a): (tensor(1.5691), tensor(0))
max(a)_value: tensor(1.5691)
max(a)_index: tensor(0)
二、针对2维数据
此时的tensor(行,列),可以理解为一张特征图,dim=0就是行间进行比较,dim=1就是列间进行比较
import torch
a = torch.randn(2,3) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是行
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0) #dim=0,行与行之间进行比较,所以返回每一列的最大值
print("max(a)_1:", max_1) #dim=1,列与列之间进行比较,所以返回每一行的最大值
<<
a:
tensor([[ 0.1734, -0.7264, 0.6981],
[ 0.0859, 1.2663, -0.0851]])
************************************************
max(a)_0: (tensor([ 0.1734, 1.2663, 0.6981]), tensor([ 0, 1, 0]))
max(a)_1: (tensor([ 0.6981, 1.2663]), tensor([ 2, 1]))
三、针对3维数据
此时的tensor(通道,行,列),可以理解为很多张特征图叠加在一起,dim=0就是通道间进行比较,dim=1就是行间进行比较,dim=2就是列间进行比较。
import torch
a = torch.randn(2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是通道
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是行
max_2=torch.max(a,dim=2) #针对第2个元素“4”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0) #dim=0,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_1:", max_1) #dim=1,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_2:", max_1) #dim=1,列与列之间进行比较,所以返回每一张特征图,每一行的最大值
<<
a:
tensor([[[ 0.5323, 1.5229, -0.6122, 0.6054],
[ 1.2424, -1.6005, 0.0779, 0.9227],
[-0.6340, -0.5770, -0.1672, 0.3598]],
[[-0.3770, -0.4992, 1.8444, -1.1040],
[ 1.2238, 0.7283, -1.6462, 0.0325],
[-0.3555, -0.2599, 1.5741, 1.0683]]])
************************************************
max(a)_0: (tensor([[ 0.5323, 1.5229, 1.8444, 0.6054],
[ 1.2424, 0.7283, 0.0779, 0.9227],
[-0.3555, -0.2599, 1.5741, 1.0683]]), tensor([[ 0, 0, 1, 0],
[ 0, 1, 0, 0],
[ 1, 1, 1, 1]]))
max(a)_1: (tensor([[ 1.2424, 1.5229, 0.0779, 0.9227],
[ 1.2238, 0.7283, 1.8444, 1.0683]]), tensor([[ 1, 0, 1, 1],
[ 1, 1, 0, 2]]))
max(a)_2: (tensor([[ 1.2424, 1.5229, 0.0779, 0.9227],
[ 1.2238, 0.7283, 1.8444, 1.0683]]), tensor([[ 1, 0, 1, 1],
[ 1, 1, 0, 2]]))
四、针对4维数据
此时的tensor(batch_size,channel, 行,列),可以理解为一个批次的训练数据的集合,dim=0,是批次间的比较;dim=1,是每个批次,自己通道间的比较;dim=2对应的行比较;dim=3对应的是列比较
import torch
a = torch.randn(1,2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“1”,对应的是batch_size
max_1=torch.max(a,dim=1) #针对第2个元素“2”,对应的是通道
max_2=torch.max(a,dim=2) #针对第2个元素“3”,对应的是行
max_3=torch.max(a,dim=3) #针对第2个元素“4”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0) #dim=0,多个输入之间进行比较,返回的是每次输入时,每一张特征图,同一像素位置上的最大值
#一般为1,一张张读的
print("max(a)_1:", max_1) #dim=1,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_2:", max_2) #dim=2,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_3:", max_3) #dim=3,列与列之间进行比较,所以返回每一张特征图,每一行的最大值
<<
a:
tensor([[[[ 0.6404, 0.5116, -0.5562, 2.2283],
[-0.6507, 0.4440, 0.8723, -0.6538],
[ 0.0352, 1.0738, 0.2382, 0.7763]],
[[-0.5208, 0.4854, -0.0950, 1.3100],
[ 0.0433, -0.6561, 0.1956, -0.3584],
[-1.0031, -1.7104, 0.6768, -0.1648]]]])
************************************************
max(a)_0: (tensor([[[ 0.6404, 0.5116, -0.5562, 2.2283],
[-0.6507, 0.4440, 0.8723, -0.6538],
[ 0.0352, 1.0738, 0.2382, 0.7763]],
[[-0.5208, 0.4854, -0.0950, 1.3100],
[ 0.0433, -0.6561, 0.1956, -0.3584],
[-1.0031, -1.7104, 0.6768, -0.1648]]]), tensor([[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]],
[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]]]))
max(a)_1: (tensor([[[ 0.6404, 0.5116, -0.0950, 2.2283],
[ 0.0433, 0.4440, 0.8723, -0.3584],
[ 0.0352, 1.0738, 0.6768, 0.7763]]]), tensor([[[ 0, 0, 1, 0],
[ 1, 0, 0, 1],
[ 0, 0, 1, 0]]]))
max(a)_2: (tensor([[[ 0.6404, 1.0738, 0.8723, 2.2283],
[ 0.0433, 0.4854, 0.6768, 1.3100]]]), tensor([[[ 0, 2, 1, 0],
[ 1, 0, 2, 0]]]))
max(a)_3: (tensor([[[ 2.2283, 0.8723, 1.0738],
[ 1.3100, 0.1956, 0.6768]]]), tensor([[[ 3, 2, 1],
[ 3, 2, 2]]]))