torch.max被广泛应用在评价model的最终预测性能中,其实这个问题大家已经总结得挺详细了,例如:
https://blog.csdn.net/liuweiyuxiang/article/details/84668269
https://www.cnblogs.com/Archer-Fang/p/10651029.html
但是正如前面一个博文里网友评论的那样,似乎拿行、列来区分不太妥当。当然我也没有想到更好的办法来总结,根据例子就很快可以掌握了:
1. torch.max(a)是返回a中的最大值:
a=torch.tensor([[-2.1456, -0.6380, 1.3625],
[-1.0394, -0.9641, -0.3667]])
print(torch.max(a))将得到:
tensor(1.3625)
(另外,怎么把这个转成int呢?加上.item()即可)
另外torch.max(a)==a.max()
2. torch.max(a,1)返回的是每一行的最大值,还有最大值所在的索引:
torch.return_types.max(
values=tensor([ 1.3625, -0.3667]),
indices=tensor([2, 2]))
当然我们很多情况下只关心index(例如计算accuracy的时候),那么这时候用
torch.max(a,1)[1] 或者 a.max(1)[1] 取出来即可:
tensor([2, 2])
再加上.numpy()就可以转成array:
print(a.max(1)[1].numpy())得到:
[2 2]
基本的使用方法就是这些,但是有一个问题,为什么
torch.max(a,1)是每一行的最大值而torch.max(a,0)是每一列的最大值呢?
例如上面这个例子,print(torch.max(a,0))的输出是:
torch.return_types.max(
values=tensor([-1.0394, -0.6380, 1.3625]),
indices=tensor([1, 0, 0]))
实际上可以这样理解:0指的是在dimension 0中,各个vector之间比较,取到vector每一维的最大值。1指的是dimension 1中,每个逗号之间的元素进行比较,例如在[-2.1456, -0.6380, 1.3625]几个数中间进行比较,得到的最大值就是一个标量了,然后这些最大值拼接成一个vector。
所以可以思考一下这种情况(虽然遇到的很少,所以其实我们可以按照0,1分别对应行和列来理解):
a=torch.tensor([[[-0.2389, -0.8487, -1.5907, 0.0732],
[-0.2159, 1.1064, -1.1317, 0.6457],
[ 0.8191, 1.0146, 1.0241, 0.7042]],
[[-0.8285, 0.3628, 1.4678, 0.7984],
[ 0.1009, -0.3307, -0.8245, 0.0044],
[-1.5041, 0.5067, 0.4085, 0.2126]]])
在这个时候,print(torch.max(a,0))得到的结果是:
values=tensor([[-0.2389, 0.3628, 1.4678, 0.7984],
[ 0.1009, 1.1064, -0.8245, 0.6457],
[ 0.8191, 1.0146, 1.0241, 0.7042]]),
indices=tensor([[0, 1, 1, 1],
[1, 0, 1, 0],
[0, 0, 0, 0]]))
大家可以看看是不是这么回事。