上代码图:
import torch
a = torch.tensor([[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
]
,
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
下面开始实验:
b = torch.argmax(a, dim=0)
predict_y = torch.max(a, dim=0)
print(predict_y)
print(b)
![在这里插入图片描述](https://img-blog.csdnimg.cn/104ed62a633d41269f3dd19e1f861438.png#pic_center)
结论:argmax是max函数的第二维度罢了。