对分割标签进行处理解码处理
import os
import torch
import numpy as np
import torch.nn.functional as F
x = np.array([[0.2, 0.3], [0.4, 0.5],[0.7, 0.6]]) # 输入的是array 类型
y = np.array([[1, 0], [1, 1],[0, 1]])
y_pred = torch.from_numpy(x).float()
y_label = torch.from_numpy(y).float()
# print(y_label.shape)
pred = F.softmax(y_pred, dim=-1).numpy()
print(pred)
print("使用最值:")
print(pred.argmax(axis=-1))
print("使用阈值:")
pred = np.where(pred>0.5)
print(pred[1])
注意 np.where是的使用
得到的最值或者满足条件的索引,得到的是元组形式,这里我们只使用列。
pred = np.where(pred>0.5)
print(pred[1])
以上我们是用的是
np.argwhere和np.where 详情见
https://www.ngui.cc/el/1148979.html?action=onClick