# Select the rows whose values along the 2nd dimension (dim=1) are all less than 3.
import torch
# A row is selected when every entry along dim=1 is strictly below this value.
THRESHOLD = 3

data_a = torch.arange(0, 6)
# clone() gives data_b its own storage; reshape() alone may return a view of
# data_a, and any in-place write through it would modify data_a as well.
data_b = data_a.clone().reshape([3, 2])

# Boolean mask over rows (dim 0): True where all entries along dim 1 are
# below THRESHOLD. Testing the condition directly with .all(dim=1) replaces
# the earlier "zero the small values, then check row sum == 0" trick. That
# trick mutated data_a through a view and is only valid for a positive
# threshold: with THRESHOLD <= 0, entries >= THRESHOLD can be zero or
# negative, and row sums could be 0 for rows that should not match.
row_mask = (data_b < THRESHOLD).all(dim=1)
print(row_mask)

# Rows of data_b that satisfy the condition (shape: [num_matching_rows, 2]).
selected = data_b[row_mask]
print(selected)