import torch
import numpy as np
# Demo: overwrite the largest element along the last axis of a random tensor.
a = torch.rand(2, 2, 2)
print(a)

# Maximum along dim 2. keepdim=True keeps the reduced axis (size 1) so that
# `values` broadcasts against `a` below. The method form
# `a.max(2, keepdim=True)` returns the same (values, indices) pair, so the
# reduction is computed only once.
values, indices = torch.max(a, dim=2, keepdim=True)

# Zero exactly at the positions holding the maximum along dim 2.
c = a - values

# Index of the maximum along dim 2, converted to uint8.
b = torch.argmax(a, dim=2).type(torch.uint8)

# Replace every maximal element in place with the value 2.
a[c == 0] = 2
print(a)
# Simpler approach:
# Boolean-mask selection of the elements of `a` equal to the maximum along
# dim 2. The result is not assigned, so as a script statement it is discarded.
a[a.eq(a.max(dim=2, keepdim=True)[0])]