神奇的 `torch.max` 函数：你品，你细品
import torch
# Demonstrate what torch.max(input, dim) returns and the common ways of
# extracting the argmax index from that result.
#
# With a `dim` argument, torch.max returns the named tuple
# torch.return_types.max(values, indices): element [0] (.values) holds the
# maximum along `dim`, element [1] (.indices) holds the position of it.

# One row of 10 standard-normal samples; shape (1, 10).
a = torch.randn(1, 10)
print(a)
print("-------------------")

# Full result: the (values, indices) named tuple, reduced over dim 1.
out = torch.max(a, dim=1)
print(out)
print("-------------------")

# Element 1 of the tuple is the indices tensor, shape (1,).
out = torch.max(a, dim=1)[1]
print(out)
print("-------------------")

# Older code appends `.data` here; `.detach()` is the recommended way to get
# a tensor detached from the autograd graph (identical values in this case).
out = torch.max(a, dim=1)[1].detach()
print(out)
print("-------------------")

# .numpy() converts the (CPU) tensor to a NumPy array of shape (1,).
out = torch.max(a, dim=1)[1].detach().numpy()
print(out)
print("-------------------")

# .squeeze(0) drops the length-1 axis, leaving a 0-d array that prints as a
# bare number.
out = torch.max(a, dim=1)[1].detach().numpy().squeeze(0)
print(out)
# Separator between every section, matching the recorded output.
print("-------------------")

# Idiomatic form: unpack the named tuple and discard the values.
_, out = torch.max(a, dim=1)
print(out)
输出
tensor([[ 0.3525, 0.5193, 0.3436, -0.0077, 0.3875, -1.3341, 0.7200, 0.3214, -1.1317, 0.8075]])
-------------------
torch.return_types.max(
values=tensor([0.8075]),
indices=tensor([9]))
-------------------
tensor([9])
-------------------
tensor([9])
-------------------
[9]
-------------------
9
-------------------
tensor([9])