1、torch.max()
torch.max(input) → Tensor
返回 input tensor 中所有元素的最大值。
torch.max(input, dim, keepdim=False) → output tensors (max, max_indices)
参数:
input:输入的 tensor。
dim:按什么维度求最大值(2D中,0代表按列求最大值,1代表按行求最大值)。
keepdim:是否保持 input tensor 的维度,True 代表 out tensor 与 input tensor 的维度相同,False 代表 out tensor 与 input tensor 的维度不同。
返回值:
返回一个 namedtuple (values, indices),values表示指定维度的最大值,indices表示最大值所在的索引,如果给定维度有多个最大值,返回第一个最大值所在的索引。
torch.max(output, 2, keepdim=True)[1]
torch.max 返回一个数组
[1]:就是返回数组的第二个值
[0]:就是返回数组的第一个值
https://cloud.tencent.com/developer/article/1914026