torch.max(X,dim=1)是对行取最大值 dim=1,表面上感觉时对列取最大值,测试一下: X = torch.tensor([[1.0, 1.0], [-1.0, -1.0]]) result,indices = torch.max(X,dim=1) print(result) print(indices) tensor([ 1., -1.]) 如果是对列取最大值,结果应该都是1,因此是对行取最大值 tensor([0, 0])