Pytorch的torch.where()函数和numpy中的np.where()函数使用方法完全是一致的。torch.where()需要三个参数:torch.where(第一个为判断条件,第二个为满足条件设置的值,第三个为不满足条件设置的值)。等同于一个基本判断结构:
if 条件:
值1
else:
值2
example:
y_pred = torch.where(y_pred > 0.5,
torch.ones_like(y_pred, dtype=torch.float32),
torch.zeros_like(y_pred, dtype=torch.float32))
而numpy的np.where()用法完全一样,唯一的不同在于torch.where()的返回值数据类型为Tensor,而np.where()的返回值由具体情况确定。