1 Official Documentation
1.1 torch.argmax()
Returns the indices of the maximum values.
Signature:
torch.argmax(input, dim, keepdim=False) → LongTensor
Returns:
Returns the indices of the maximum values of a tensor across a dimension.
Parameters:
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.
keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
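The effect of dim and keepdim is easiest to see from the output shapes. Below is a minimal sketch using a small hand-written tensor of shape (2, 3). The values are made up for illustration and chosen so that each row has a single maximum.

```python
import torch

# Illustrative values (not from the original post), shape (2, 3).
x = torch.tensor([[1.0, 5.0, 3.0],
                  [7.0, 2.0, 4.0]])

# Reduce over dim=1: one index per row, and the reduced dimension is dropped.
idx = torch.argmax(x, dim=1)
print(idx, idx.shape)            # tensor([1, 0]) torch.Size([2])

# keepdim=True keeps the reduced dimension with size 1.
idx_keep = torch.argmax(x, dim=1, keepdim=True)
print(idx_keep.shape)            # torch.Size([2, 1]), values [[1], [0]]

# No dim: the index refers to the flattened tensor [1, 5, 3, 7, 2, 4].
flat = torch.argmax(x)
print(flat)                      # tensor(3)

# The result is a LongTensor (int64).
print(idx.dtype)                 # torch.int64
```

The size-1 dimension left by keepdim=True is what lets the index tensor broadcast against, or be passed to indexing functions together with, the original tensor.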
1.2 torch.argmin()
Returns the indices of the minimum values.
Signature:
torch.argmin(input, dim, keepdim=False) → LongTensor
Returns:
Returns the indices of the minimum values of a tensor across a dimension.
Parameters:
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce. If None, the argmin of the flattened input is returned.
keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
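One way to sanity-check argmin is to compare it with torch.min(input, dim), which returns a named tuple holding both the minimum values and their indices. The sketch below uses made-up values with no ties in any row, so the two results are unambiguous.

```python
import torch

# Illustrative values (not from the original post), no ties within a row.
x = torch.tensor([[4.0, -2.0, 0.5],
                  [-3.0, 1.0, 6.0]])

# Index of the smallest element in each row.
idx = torch.argmin(x, dim=1)
print(idx)                       # tensor([1, 0])

# torch.min with a dim returns (values, indices).
result = torch.min(x, dim=1)
print(result.values)             # tensor([-2., -3.])
print(result.indices)            # tensor([1, 0])

# The indices agree with argmin for this input.
assert torch.equal(idx, result.indices)

# Without dim, the index refers to the flattened tensor.
flat = torch.argmin(x)
print(flat, x.flatten()[flat])   # tensor(3) tensor(-3.)
```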
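2 Code Examples
The matrices below are generated with torch.randn, so their values (and therefore the returned indices) change on every run.
2.1 torch.argmax() Example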
>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.3772, -0.1143],
         [ 0.2217, -0.1897]],

        [[ 0.1488, -0.8758],
         [ 1.7734, -0.5929]]])
>>> print(Matrix.argmax(dim=0))
tensor([[0, 0],
        [1, 0]])
>>> print(Matrix.argmax(dim=1))
tensor([[0, 0],
        [1, 1]])
>>> print(Matrix.argmax(dim=2))
tensor([[0, 0],
        [0, 0]])
>>> print(Matrix.argmax())
tensor(6)
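Because torch.randn draws new values each run, the transcript above cannot be reproduced directly. The following sketch types in the printed values by hand (they are rounded to four decimals, which does not change any comparison here). It then checks what the indices mean in two ways. First, Tensor.gather uses each index to pick an element along dim, and those picks should equal the maxima. Second, divmod converts the flat index returned by argmax() back into (i, j, k) coordinates in row-major order. The variable names are illustrative.

```python
import torch

# The Matrix printed above, written out explicitly for reproducibility.
m = torch.tensor([[[0.3772, -0.1143],
                   [0.2217, -0.1897]],
                  [[0.1488, -0.8758],
                   [1.7734, -0.5929]]])

for d in range(m.dim()):
    idx = m.argmax(dim=d)
    # gather needs an index with the same number of dims, so re-insert
    # the reduced dimension, pick the elements, then drop it again.
    picked = m.gather(d, idx.unsqueeze(d)).squeeze(d)
    assert torch.equal(picked, m.max(dim=d).values)
    print(d, idx.tolist())

# argmax() without dim indexes the flattened tensor (row-major order).
flat = m.argmax().item()
i, rem = divmod(flat, m.shape[1] * m.shape[2])
j, k = divmod(rem, m.shape[2])
print(flat, (i, j, k))           # 6 (1, 1, 0)
```

Flat index 6 corresponds to 1 * 4 + 1 * 2 + 0, that is, Matrix[1][1][0] = 1.7734, the largest value in the printout.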
2.2 torch.argmin() Example
>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.5821,  0.2889],
         [ 0.4669, -0.3135]],

        [[-0.4567,  0.2975],
         [-1.5084,  0.7320]]])
>>> print(Matrix.argmin(dim=0))
tensor([[1, 0],
        [1, 0]])
>>> print(Matrix.argmin(dim=1))
tensor([[1, 1],
        [1, 0]])
>>> print(Matrix.argmin(dim=2))
tensor([[1, 1],
        [0, 0]])
>>> print(Matrix.argmin())
tensor(6)
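As with argmax, the argmin transcript can be checked by typing in the printed values. This sketch also shows a practical use of keepdim=True. It leaves a size-1 axis at position dim, which is the index shape torch.take_along_dim expects, so the minima can be fetched without any unsqueeze. It assumes a PyTorch version that provides torch.take_along_dim (added around 1.9) and torch.amin. The variable names are illustrative.

```python
import torch

# The Matrix printed in 2.2, written out explicitly for reproducibility.
m = torch.tensor([[[ 0.5821,  0.2889],
                   [ 0.4669, -0.3135]],
                  [[-0.4567,  0.2975],
                   [-1.5084,  0.7320]]])

for d in range(m.dim()):
    # keepdim=True keeps a size-1 axis at position d, matching what
    # torch.take_along_dim needs for its indices argument.
    idx = m.argmin(dim=d, keepdim=True)
    mins = torch.take_along_dim(m, idx, dim=d)
    assert torch.equal(mins, m.amin(dim=d, keepdim=True))
    print(d, idx.squeeze(d).tolist())

# Flattened index of the global minimum and the value stored there.
flat = m.argmin().item()
print(flat, m.flatten()[flat].item())
```

The loop prints the same index matrices as the transcript, and the flat index 6 again maps to position (1, 1, 0), where -1.5084 sits.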