函数:torch.max(input, dim, keepdim=False, out=None)
返回值:Tensor, LongTensor
作用:输入tensor,按维度返回其最大值,并返回其最大维度的索引。
创建2个张量a和b
a = torch.tensor([[10, 5, 1], [2, 4, 3]])
b = torch.tensor([[6, 8, 9], [1, 4, 2]])
print(a)
print(b)
tensor([[10, 5, 1],
[ 2, 4, 3]])
tensor([[6, 8, 9],
[1, 4, 2]])
输出其最大值
print(torch.max(a))
print(torch.max(b))
tensor(10)
tensor(9)
按维度输出最大值,其中0是按列索引,1是按行索引,在输出最大值的同时,输出最大值所在的索引的位置。
*在深度学习中,知道了最大值索引位置,就知道了one-hot编码,知道了one-hot编码,就可以进行类别预测了。
print(torch.max(a, 0))
print("-------------------------")
print(torch.max(b, 1))
torch.return_types.max(
values=tensor([10, 5, 3]),
indices=tensor([0, 0, 1]))
-------------------------
torch.return_types.max(
values=tensor([9, 4]),
indices=tensor([2, 1]))
比较a和b,按维度输出最大值
c = torch.max(a, b)
print(c)
tensor([[10, 8, 9],
[ 2, 4, 3]])
如果觉得有帮助,欢迎点赞+收藏,笔芯~