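import torch

# AlexNet: the network class defined earlier in this post (assumed to take num_classes);
# torchvision.models.AlexNet(num_classes=5) accepts the same argument.
model = AlexNet(num_classes=5)
x = torch.randn((5, 3, 224, 224))   # a batch of 5 random 3x224x224 images
y = model(x)                        # raw class scores (logits), shape [5, 5]
# print(x.shape)                    # torch.Size([5, 3, 224, 224])
# print(y.shape)                    # torch.Size([5, 5])
print(y)
print(torch.max(y, dim=1))          # (values, indices) of each row's maximum
print(torch.max(y, dim=1)[1])       # indices only: the predicted class of each sample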
Output:
tensor([[-0.0140,  0.0005, -0.0043, -0.0048,  0.0010],
        [-0.0233,  0.0032, -0.0008, -0.0094,  0.0052],
        [-0.0128,  0.0168, -0.0010, -0.0047,  0.0025],
        [-0.0131,  0.0177, -0.0092, -0.0070,  0.0145],
        [-0.0258,  0.0088, -0.0095,  0.0063,  0.0106]],
       grad_fn=<AddmmBackward0>)
torch.return_types.max(
values=tensor([0.0010, 0.0052, 0.0168, 0.0177, 0.0106], grad_fn=<MaxBackward0>),
indices=tensor([4, 4, 1, 1, 4]))
tensor([4, 4, 1, 1, 4])
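So what torch.max(y, dim=1)[1] does:
for each row of y (one sample), it looks across dim=1 (the 5 class scores), finds the maximum, and returns the column index of that maximum, i.e. the predicted class. With a dim argument, torch.max returns a (values, indices) pair, and [1] selects the indices.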
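To make "one index per row" concrete, here is a minimal sketch on a small hand-made tensor (the numbers are invented for illustration, not taken from the AlexNet output above). It also checks that [1] is the same as the named field .indices and as torch.argmax(y, dim=1), and contrasts dim=1 with dim=0.

```python
import torch

# Hypothetical score matrix: 3 samples (rows) x 4 classes (columns).
y = torch.tensor([[0.1, 0.7, 0.2, 0.0],
                  [0.5, 0.1, 0.3, 0.1],
                  [0.2, 0.2, 0.1, 0.9]])

# dim=1 reduces across the columns, so each row yields one (value, index) pair.
result = torch.max(y, dim=1)
values, indices = result            # the returned namedtuple unpacks like a tuple
print(values)                       # tensor([0.7000, 0.5000, 0.9000])
print(indices)                      # tensor([1, 0, 3])

# Indexing with [1] is the same as the named field .indices ...
assert torch.equal(result[1], result.indices)
# ... and matches torch.argmax along the same dimension.
assert torch.equal(result[1], torch.argmax(y, dim=1))

# For contrast, dim=0 reduces down the rows: one index per column.
print(torch.max(y, dim=0).indices)  # tensor([1, 0, 1, 2])
```

In a classifier, dim=1 is the usual choice because each row holds one sample's class scores; dim=0 would instead report which sample scores highest for each class.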