if __name__ == '__main__': # pdist = torch.nn.PairwiseDistance() a = torch.tensor([1,2,3,5,6]) b = torch.tensor([1,2,3,4,5]) idx = np.where(a!=b)[0] print(a,b,idx)
tensor([1, 2, 3, 5, 6]) tensor([1, 2, 3, 4, 5]) [3 4]
原因是因为会默认将tensor转换成numpy。当然也可以直接用torch.where实现。