np.where和torch.where的使用
两者都是三个输入参数,第一个是判断条件,第二个是符合条件的设置值,第三个是不满足条件的设置值。
区别在于,torch要将设置值全部改为Tensor类型的
#预测概率
y_pred_probs = net(torch.tensor(x_test[0:10]).float()).data
y_pred_probs
#预测类别
y_pred = torch.where(y_pred_probs>0.5,
torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))
y_pred
y_pred = np.where(np.isnan(y_pred_probs), 0, 1)
y_pred