【机器学习系列】如何将多条ROC曲线画在一张图里,并解决文本遮挡问题

有的时候我们需要将ROC曲线输出在同一张图中,这样可以更加直观地对比模型;并且我们常常会遇到在图形中有文字相互遮挡的问题,我们可以用adjustText中的adjust_text来实现文本不相互遮挡并添加箭头的功能。

 定义多条roc曲线画图函数

def multi_models_roc(names, prob_results, colors,linestyles, y_test, save=True, dpin=100):
   """
   将多个机器模型的roc图输出到一张图上

   Args:
       names: list, 多个模型的名称
       prob_results: 使用模型预测的概率值(predict_proba()函数的返回值)
       colors: 想绘制的曲线的颜色列表
       linestyles: 想绘制的曲线的线型
       save: 选择是否将结果保存(默认为png格式)

   Returns:
       返回图片对象plt
   """
        
    plt.figure(figsize=(10, 10), dpi=dpin)
    from adjustText import adjust_text
    texts = []
    for (name, result, colorname,linestylename) in zip(names, prob_results, colors, linestyles):
        y_test_predprob = result[:,1]
        fpr, tpr, thresholds = roc_curve(y_test, y_test_predprob)
        
        optimal_th, optimal_point = Find_Optimal_Cutoff(TPR=tpr, FPR=fpr, threshold=thresholds)
#         plt.plot(optimal_point[0], optimal_point[1], marker='o', color='r')
#         texts.append(plt.text(optimal_point[0], optimal_point[1], name+' '+f'Threshold:{optimal_th:.2f}'))
        texts.append(plt.text(optimal_point[0], optimal_point[1], name))
        plt.plot(fpr, tpr, lw=3, label='{} (AUC={:.3f})'.format(name, auc(fpr, tpr)),color = colorname,linestyle=linestylename)
        plt.plot([0, 1], [0, 1], '--', lw=3, color = 'grey')
        plt.axis('square')
        plt.xlim([0, 1])
        plt.ylim([0, 1.05])
        plt.xlabel('False Positive Rate',fontsize=10)
        plt.ylabel('True Positive Rate',fontsize=10)
        plt.title('ROC Curve',fontsize=20)
        plt.legend(loc='lower right',fontsize=10)
    adjust_text(texts, 
            arrowprops=dict(
    			arrowstyle='->',#箭头样式 
    			lw= 2,#线宽
    			color='red')#箭头颜色
           )
    if save:
        plt.savefig('multi_models_roc.png')
    return plt

调用函数画图

names = ['Logistic Regression',
         'Naive Bayes',
         'Decision Tree',
         'Random Forest',
         'SVM',
         'Neural Network',
         'GBDT',
         'LightGBM',
         'XGBoost'
        ]

#这是各个模型的预测值返回列表
prob_results = [lg_y_prob,
                nb_y_prob,
                tree_y_prob,
                rf_y_prob,
                svm_y_prob,
                bp_y_prob,
                gbdt_y_prob,
                lgb_y_prob,
                xgb_y_prob
                ]
 
colors = ['crimson',
          'orange',
          'gold',
          'mediumseagreen',
          'steelblue', 
          'mediumpurple' ,
          'black',
          'silver',
          'navy'
         ]

linestyles = ['-', '--', '-.', ':', 'dotted', 'dashdot', '--', 'solid', 'dashed']
 
#ROC curves
train_roc_graph = multi_models_roc(names, prob_results, colors, linestyles,  Y_test_smo_tmo, save = True)
train_roc_graph.savefig('ROC_Train_all.png')

结果展示:

 

 

 

  • 9
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值