在pytorch中看到torch.max()函数,图片数据的维度一般是[b,c,h,w]
例子就忽略b维度,随机初始化一个三维度的tensor,然后测试在每个维度上使用torch.max()的结果。
1.首先是w维度,也就是三维向量列的维度
import torch
input=torch.randn(2,3,3)
print(input)
input_max=torch.max(input,dim=2,keepdim=True)[0]
print(input_max)
print(input_max.shape)
上面的原始tensor数据是2x3x3的,也就是2块,3行,3列。如果在w维度执行torch.max()函数,得到的tensor数据就是在每一块的每一行选出最大的数值,所以每一块中只有一列数据,所以得出的结果是2x3x1的tensor。
为什么是在每一块的每一行中选出最大值呢?(因为同一行的数据是属于列维度的)
对于一个2x2的二维向量[[1,2,3],[4,5,6]],1,2,3、4,5,6是处于同一维度的(列维度),而1,4、2,5、3,6是处于同一维度的(行维度)。而torch.max()就是找出指定维度的最大值。上例中要取w维度的最大值,所以是在每一块的每一行中找出最大值。
2.首先是h维度,也就是三维向量行的维度
只需要将上面代码中torch.max()的dim参数的值改为1即可。
在h维度上执行,就是在每一块的每一列中找出最大值。所以每一块中只有一行数据,所以结果是2x1x3的tensor。
这三个数据是属于原数据的行的维度的,所以在行的维度上执行torch.max()函数,是在每一列中寻找最大值。
3. 最后是c维度,也就是三维向量块(根据上述的数据随意起的名字)的维度
只需要将上面代码中torch.max()的dim参数的值改为0即可。
在该维度上执行,就是要在各个块中找出最大的一块,所以得出的结果只有一块也就是1x3x3的tensor。
实际上是每一个块对应位置上的数据比较,选择最大值来拼成最后的结果。