计算混淆矩阵代码
'''
产生n×n的分类统计表
参数a:标签图(转换为一行输入),即真实的标签
参数b:score层输出的预测图(转换为一行输入),即预测的标签
参数n:类别数
'''
def fast_hist(a, b, n):
#k为掩膜(去除了255这些点(即标签图中的白色的轮廓),其中的a>=0是为了防止bincount()函数出错)
k = (a >= 0) & (a < n)
#bincount()函数用于统计数组内每个非负整数的个数
return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)
注:可以通过改变这里使得求得的混淆矩阵反一下。
评价指标代码
首先需要计算出TF,TP,FP,FN
import glob
import numpy as np
import cv2
from sklearn.metrics import classification_report
import collections
import os
'''
产生n×n的分类统计表
参数a:标签图(转换为一行输入),即真实的标签
参数b:score层输出的预测图(转换为一行输入),即预测的标签
参数n:类别数
'''
def fast_hist(a, b, n):
#k为掩膜(去除了255这些点(即标签图中的白色的轮廓),其中的a>=0是为了防止bincount()函数出错)
k = (a >= 0) & (a < n)
#bincount()函数用于统计数组内每个非负整数的个数
return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)
# print(collections.Counter(image)) # 统计查看image的类别数目
if __name__ == '__main__':
Pa = []
Cpa = []
Cpa2 = []
Iou = []
Dice = []
lab_path = glob.glob(os.path.join('test_true/', '*.jpg'))
pre_path = glob.glob(os.path.join('test/', '*res.png'))
for i in range(len(lab_path)):
# 处理图像
image_path = lab_path[i]
image = cv2.imread(image_path)
image = image.ravel()
image[np.where(image < 122)] = 0
image[np.where(image >= 122)] = 1
lable_path = pre_path[i]
lable = cv2.imread(lable_path)
lable = lable.ravel()
lable[np.where(lable == 255)] = 1
# 计算混淆矩阵
hist = fast_hist(image, lable, 2)
TP = hist[0][0]
TN = hist[1][1]
FP = hist[1][0]
FN = hist[0][1]
# 计算各项指标
pa = (TN+TP)/(TN+TP+FN+FP)
cpa = TN/(TN+FN)
cpa2 = TP/(TP+FP)
iou = FN/(FN+FP+TN)
dice = 2*TP/(2*TP+FN+FP)
Pa.append(pa)
Cpa.append(cpa)
Cpa2.append(cpa2)
Iou.append(iou)
Dice.append(dice)
# 列表转为Numpy
Pa_num = np.array(Pa)
Cpa_num = np.array(Cpa)
Cpa_num2 = np.array(Cpa2)
Iou_num = np.array(Iou)
Dice_num = np.array(Dice)
print('图片的平均Pa:{:.4f}±{:.4f}'.format(Pa_num.mean(), Pa_num.std()))
print('白块的平均Cpa:{:.4f}±{:.4f}'.format(Cpa_num.mean(), Cpa_num.std()))
print('类别平均像素准确率Mpa:{:.4f}'.format((Cpa_num.mean()+Cpa_num2.mean())/2))
print('白块类别平均IoU:{:.4f}±{:.4f}'.format(Iou_num.mean(), Iou_num.std()))
print('Dice系数:{:.4f}±{:.4f}'.format(Dice_num.mean(), Dice_num.std()))