一、torch.max(input, dim) 函数
output = torch.max(input, dim)
输入:
input 是一个tensor
dim 是 max 函数索引的维度,dim 为 0 时返回每列最大值,dim 为 1 时返回每行最大值
输出:
函数会返回两个tensor,第一个 tensor 是某维度(dim)上的最大值;第二个 tensor 是最大值的索引(位置)。
二、实例
import torch
a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
print(a)
# dim 为 1 时返回每行最大值
print(torch.max(a, 1))
print(torch.max(a, 1)[1].numpy())
输出结果:
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
torch.return_types.max(
values=tensor([62, 6, 65]),
indices=tensor([2, 1, 1]))
[2 1 1]