pytorch按条件筛选Tensor 情景描述 已知有一个二维Tensor,规模为(3,2)。要求筛选后的tensor每行的最大值都大于5。 # probs 是一个二维Tensor import torch def filtrate(probs, criterion=5): mask = torch.max(probs,dim=1)[