先来看torch.argmax和torch.max的区别
D = torch.tensor([1, 2, 3, 4, 5, 6])
print(D.max(0))
print("------------------------------------------------------")
print(D.argmax(0))—————输出—————
torch.return_types.max(
values=tensor(6),
indices=tensor(5))
------------------------------------------------------
tensor(5)
torch.max求指定维度下的最大值,并返回最大值及其位置,torch.argmax求指定维度下的最大值的位置并返回位置
和torch.max差不多,对哪一维度求argmax哪一维度就消失合并
testData = torch.tensor(
[
[
[[1, 9], [2, 3], [11, 3], [55, 4]],
[[4, 2], [3, 1], [4, 2], [6, 1]],
[[1, 2], [5, 33], [3, 5], [1, 9]],
[[1, 5], [2, 7], [8, 3], [6, 4]],
[[32, 4], [23, 3], [11, 2], [1, 6]],
[[2, 6], [5, 8], [3, 14], [1, 2]],
],
[
[[1, 9], [2, 3], [11, 3], [55, 4]],
[[4, 2], [3, 1], [4, 2], [6, 1]],
[[1, 2], [5, 33], [3, 5], [1, 9]],
[[1, 5], [2, 7], [8, 3], [6, 4]],
[[32, 4], [23, 3], [11, 2], [1, 6]],
[[2, 6], [5, 8], [3, 14], [1, 2]],
],
]
)
print(testData.argmax(2))
print(testData.argmax(2).shape)
print(testData.argmax(1))
print(testData.argmax(1).shape)
如上例子,先来看对2维度求 argmax的输出
tensor([
[[3, 0],[3, 0],[1, 1],[2, 1],[0, 3],[1, 2]],
[[3, 0],[3, 0],[1, 1],[2, 1],[0, 3],[1, 2]]
])
torch.Size([2, 6, 2])
原本的shape为torch.Size([2, 6, 4, 2]),对二维度进行求argmax后变成了torch.Size([2, 6, 2])二维度消失,于是二维度需要合并,所以就要把[4,2]变成[2],所以用所有的在三维度的进行每个位置上的数据进行对比,并选出最大的,组成新的shape为[2]的一个维度