一、一个参数时的torch.argmax函数
torch.argmax(input) -> LongTensor
该函数返回输入张量中所有元素中最大值的索引。(如果有多个最大值,则返回第一个最大值的索引)
注:索引从0开始。
参数:
- input (Tensor) - 输入张量。
实例:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]