torch.max的用法
torch.max(Tensor,index)是求Tensor格式下的最大值包含两部分,一部分是最大值,另一部分是最大值的索引
a = torch.randn(3,3)
print(a)
>>tensor([[ 0.4538, -0.0595, 0.6461],
[-2.0434, 0.5453, -1.2888],
[ 0.6211, -0.7173, 0.2639]])
#每列的最大值以及最大值的索引
print(torch.max(a,0))
>>torch.return_types.max(
values=tensor([0.6211, 0.5453, 0.6461]),
indices=tensor([2, 1, 0]))
print(torch.max(a,0)[0])
>>tensor([0.6211, 0.5453, 0.6461])
print(torch.max(a,0)[1])
>>tensor([2, 1, 0])
#每行的最大值以及最大值的索引
print(torch.max(a,1))
>>torch.return_types.max(
values=tensor([0.6461, 0.5453, 0.6211]),
indices=tensor([2, 1, 0]))
print(torch.max(a,1)[0])
>tensor([0.6461, 0.5453, 0.6211])
print(torch.max(a,1)[1])
>tensor([2, 1, 0])