pytorch中mask_select()的用法 import torch a =torch.Tensor([1,2,4,4,5]) print(torch.masked_select(a, a<4)) 1.a<4取出的是布尔值索引(掩码)[1,1,0,0,0,] 2.torch.masked_select(a, a<4):根据a<4的非0掩码从a中取值 print(torch.masked_select(a, a<4)):