from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import torch
def select_predicted_classes(prediction_scores, labels, ignore_index=-100):
    """Return the argmax class at every position whose label is not ignored.

    Args:
        prediction_scores: tensor of shape ``(..., num_classes)``; the class
            scores must be along the last dimension.
        labels: tensor whose shape equals ``prediction_scores.shape[:-1]``.
            Positions whose label equals ``ignore_index`` are dropped.
        ignore_index: sentinel value marking positions to skip. Defaults to
            -100, the same default ``torch.nn.CrossEntropyLoss`` uses.

    Returns:
        1-D int64 tensor of predicted class indices for the kept positions,
        in row-major order.

    Raises:
        ValueError: if ``labels`` does not match the leading dimensions of
            ``prediction_scores``.
    """
    predicted = torch.argmax(prediction_scores, dim=-1)
    if predicted.shape != labels.shape:
        # masked_select broadcasts its mask, so a mismatched label tensor
        # would silently select the wrong positions instead of failing.
        raise ValueError(
            "labels shape {} does not match prediction_scores shape {} "
            "without its last dimension".format(
                tuple(labels.shape), tuple(prediction_scores.shape)))
    return torch.masked_select(predicted, labels != ignore_index)


# Demo: random scores of shape (3, 4, 5) with classes on the last dim.
# Exactly one position per row gets a real label; every other position
# holds the -100 "ignore" sentinel.
prediction_scores = torch.randn([3, 4, 5])
labels = torch.full((3, 4), -100, dtype=torch.long)
labels[0][1] = 2
labels[1][2] = 1
labels[2][3] = 3
predicted = torch.argmax(prediction_scores, 2)
predicted_classes = select_predicted_classes(prediction_scores, labels)
# Original (misspelled) names kept as aliases in case other code uses them.
predcit = predicted
predcit_calss = predicted_classes
print(predicted_classes)