- 分别计算PR：针对单个类别统计精确率（Precision）和召回率（Recall）
def cal_PR(cls, pred_ls, gt_ls):
    """Compute precision and recall for a single class label.

    Every label (``cls`` and each element of ``pred_ls`` / ``gt_ls``) is
    passed through ``int()`` before comparison, so ``"1"`` and ``1`` count as
    the same class.

    Args:
        cls: The class label to evaluate (any value ``int()`` accepts).
        pred_ls: Predicted labels, indexed in step with ``gt_ls``.
        gt_ls: Ground-truth labels; its length determines how many
            (prediction, ground-truth) pairs are counted.

    Returns:
        A ``(precision, recall)`` tuple of floats, or ``(None, None)`` when
        either metric is undefined: no sample was predicted as ``cls``
        (tp + fp == 0) or no ground-truth sample is ``cls`` (tp + fn == 0).

    Raises:
        IndexError: if ``pred_ls`` is shorter than ``gt_ls``.
        ValueError / TypeError: if a label cannot be converted with ``int()``.
    """
    target = int(cls)
    tn, fn, tp, fp = 0, 0, 0, 0
    # Walk gt_ls and index into pred_ls (rather than zip) so a shorter
    # pred_ls raises IndexError instead of being silently truncated.
    for i, gt in enumerate(gt_ls):
        pred_hit = int(pred_ls[i]) == target
        gt_hit = int(gt) == target
        if pred_hit and gt_hit:
            tp += 1
        elif pred_hit:
            fp += 1
        elif gt_hit:
            fn += 1
        else:
            tn += 1
    print("tp:", tp, " fp:", fp, " tn:", tn, " fn:", fn)
    if tp + fp == 0:
        # Precision has a zero denominator.
        print("tp and fp are 0, cannot cal P")
        return None, None
    if tp + fn == 0:
        # Recall has a zero denominator.
        print("tp and fn are 0, cannot cal R")
        return None, None
    p = float(tp) / (tp + fp)
    r = float(tp) / (tp + fn)
    return p, r
- 一个函数搞定：一次遍历同时统计所有类别的精确率和召回率
def cal_P_R(pred_ls, gt_ls, labels=("0", "1")):
    """Compute per-class precision and recall in a single pass.

    Labels are compared with ``==`` as-is (no type conversion), so the label
    type in ``labels`` must match the type used in ``pred_ls`` / ``gt_ls``.

    Args:
        pred_ls: Predicted labels.
        gt_ls: Ground-truth labels, paired element-wise with ``pred_ls``.
            Pairing uses ``zip``, so extra elements in the longer list are
            ignored.
        labels: Labels that always appear in the result (in this order), even
            if they never occur in the data. Any other label found in the data
            is appended after them. Defaults to ``("0", "1")``.

    Returns:
        A dict mapping each label to a dict with integer counts ``"tp"``,
        ``"fp"``, ``"fn"`` plus ``"Precision"`` and ``"Recall"``. A metric
        whose denominator is zero is reported as ``0``.
    """
    stat_dict = {key: {"tp": 0, "fp": 0, "fn": 0} for key in labels}
    for pred, gt in zip(pred_ls, gt_ls):
        # Register labels not listed in `labels` instead of raising KeyError.
        for key in (pred, gt):
            stat_dict.setdefault(key, {"tp": 0, "fp": 0, "fn": 0})
        if pred == gt:
            stat_dict[pred]["tp"] += 1
        else:
            # A miss is a false positive for the predicted class and a
            # false negative for the true class.
            stat_dict[pred]["fp"] += 1
            stat_dict[gt]["fn"] += 1
    for counts in stat_dict.values():
        predicted = counts["tp"] + counts["fp"]
        actual = counts["tp"] + counts["fn"]
        counts["Precision"] = counts["tp"] / predicted if predicted else 0
        counts["Recall"] = counts["tp"] / actual if actual else 0
    return stat_dict