对于多分类问题,ROC曲线的获取主要有两种方法:
假设测试样本个数为m,类别个数为n。在训练完成后,计算出每个测试样本的在各类别下的概率或置信度,得到一个[m, n]形状的矩阵P,每一行表示一个测试样本在各类别下概率值(按类别标签排序)。相应地,将每个测试样本的标签转换为类似二进制的形式,每个位置用来标记是否属于对应的类别(也按标签排序,这样才和前面对应),由此也可以获得一个[m, n]的标签矩阵L。
①方法一:每种类别下,都可以得到m个测试样本为该类别的概率(矩阵P中的列)。所以,根据概率矩阵P和标签矩阵L中对应的每一列,可以计算出各个阈值下的假正例率(FPR)和真正例率(TPR),从而绘制出一条ROC曲线。这样总共可以绘制出n条ROC曲线。最后对n条ROC曲线取平均,即可得到最终的ROC曲线。
②方法二:
首先,对于一个测试样本:1)标签只由0和1组成,1的位置表明了它的类别(可对应二分类问题中的‘’正’’),0就表示其他类别(‘’负‘’);2)要是分类器对该测试样本分类正确,则该样本标签中1对应的位置在概率矩阵P中的值是大于0对应的位置的概率值的。基于这两点,将标签矩阵L和概率矩阵P分别按行展开,转置后形成两列,这就得到了一个二分类的结果。所以,此方法经过计算后可以直接得到最终的ROC曲线。
上面的两个方法得到的ROC曲线是不同的,当然曲线下的面积AUC也是不一样的。 在python中,方法1和方法2分别对应sklearn.metrics.roc_auc_score函数中参数average值为’macro’和’micro’的情况。下面参考sklearn官网提供的例子,对两种方法进行实现。
上代码:
# # 绘制roc曲线 # # y_test_one_hot = label_binarize(y_test_cls, np.arange(3)) # 将标签二值化 y_predict_one_hot = y_logits_cls plt.figure() # 绘图 mpl.rcParams['font.sans-serif'] = u'SimHei' mpl.rcParams['axes.unicode_minus'] = False # FPR就是横坐标,TPR就是纵坐标 # 计算ROC fpr_dict, tpr_dict, roc_auc = dict(), dict(), dict() for i in range(3): # 计算每一个标签的假正例率(fpr)和真正例率(tpr) fpr_dict[i], tpr_dict[i], _ = roc_curve(y_test_one_hot[:, i], y_predict_one_hot[:, i]) roc_auc[i] = auc(fpr_dict[i], tpr_dict[i]) # 两种画法: # 方法一:将所有的标签进行二值化处理后,如[[0,0,1],[0,1,0]] 转成[0,0,1,0,1,0] 转成二分类进行求解 fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(y_test_one_hot.ravel(), y_predict_one_hot.ravel()) roc_auc["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) # # 方法二: 将每个标签的fpr和tpr进行累加除以种类数,即画出平均后的roc曲面 # n_classes = 3 # from scipy import interp # all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_classes)])) # # Then interpolate all ROC curves at this points # mean_tpr = np.zeros_like(all_fpr) # for i in range(n_classes): # mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) # # Finally average it and compute AUC # mean_tpr /= n_classes # fpr_dict["macro"] = all_fpr # tpr_dict["macro"] = mean_tpr # roc_auc["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) # print(roc_auc) # 显示到当前界面,保存为svm.png lw = 2 # plt.plot(fpr_dict[2], tpr_dict[2], color='darkorange', # 画关于正面的roc曲面 # lw=lw, label='ROC curve (area = %0.3f)' % roc_auc["micro"]) plt.plot(fpr_dict["micro"], tpr_dict["micro"], color='darkorange', lw=lw, label='ROC curve (area = %0.3f)' % roc_auc["micro"]) plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.0]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver operating characteristic example') plt.legend(loc="lower right") plt.title(u'text_rnnROC和AUC', fontsize=17) path = os.path.join(file_path, "img") if not os.path.exists(path): os.makedirs(path) plt.savefig(os.path.join(file_path, "img", "{}的ROC和AUC.png".format("model_" + str(config.model_num) + "_")))