示例
代码
from sklearn.metrics import roc_curve, auc
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 10))
def plot_roc(labels, predict_probs, titles):
color = ['r', 'g', 'b', 'y']
shape = ['o', 'v', '^']
for idx, predict_prob in enumerate(predict_probs):
false_positive_rate,true_positive_rate,thresholds=roc_curve(labels, predict_prob)
roc_auc=auc(false_positive_rate, true_positive_rate)
plt.title('ROC')
c = color[idx%len(color)]
s = shape[idx%len(shape)]
plt.plot(false_positive_rate, true_positive_rate,'b',label='AUC K:{} = {:.4}'.format(titles[idx], roc_auc), color=c, marker=s, markevery=20)
plt.legend(loc='lower right')
plt.plot([0,1],[0,1],'r--')
plt.ylabel('TPR')
plt.xlabel('FPR')
plot_roc(pca_test_label, predict_probs)
解释
该代码参数含义为:
- label: 长度为N的列表,二分类的真实标签
- predict_probs:二级列表,每个元素为长度为N的列表,记录的是正类的概率。
- titles:图例中的
K:xxx
的名称列表