# 1 结果维度不变 (the result keeps its number of dimensions): boolean mask on one axis, plain slice on the other
import torch

# Example 1: a 1-D boolean mask on the column axis combined with a plain
# slice (`:`) on the row axis. Only one axis uses advanced indexing, so that
# axis is kept (its new length is the number of True entries in the mask)
# and the result stays 2-D.
y_hat = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([1, 0, 0], dtype=torch.bool)  # True only at column 0
print(y_hat[:, y])
print(y_hat[:, y].shape)
# tensor([[1], [4]])  torch.Size([2, 1])
# 2 结果维度改变 (the result loses a dimension): advanced indexing on both axes
import torch

# Example 2: advanced indexing on BOTH axes.
# - `range(len(y_hat))` supplies the row indices [0, 1].
# - The boolean mask `y` is converted to the integer indices of its True
#   entries, here [0].
# The two index sequences are broadcast together ([0, 1] with [0]) and pick
# the elements (0, 0) and (1, 0) pairwise. As a result, the output is 1-D.
# NOTE: this only works because `y` has exactly one True entry (or exactly
# one per row). Any other count makes the index shapes non-broadcastable and
# indexing fails with a shape-mismatch error. To pick one element per row,
# pass integer column indices instead, e.g. y = torch.tensor([0, 2]).
y_hat = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([1, 0, 0], dtype=torch.bool)
print(y_hat[range(len(y_hat)), y])
print(y_hat[range(len(y_hat)), y].shape)
# tensor([1, 4])  torch.Size([2])
# 3 结果维度改变 (the result loses a dimension): a full-shape boolean mask returns the selected elements as 1-D
import torch

# Example 3: a boolean mask with the same shape as y_hat. The mask consumes
# every axis, so the selected elements come back as a single 1-D tensor in
# row-major order. The 2-D structure of y_hat is lost.
y_hat = torch.tensor([[1, 2, 3], [4, 5, 6]])
y_2 = torch.tensor([[1, 0, 0], [0, 0, 1]]).bool()  # selects (0, 0) and (1, 2)
print(y_hat[y_2])
print(y_hat[y_2].shape)
# tensor([1, 6])  torch.Size([2])