catboost测试,ROC

用小数据集测试catboost,并画出ROC曲线
数据集用CSV

首先import需要的库

import pandas as pd
import numpy as np
from  catboost  import  CatBoostClassifier,CatBoostRegressor,Pool
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

pandas读入csv
去掉标题的第一行

obj=pd.read_csv('file.csv',header=1)

划分training data和test_data
这里简单地把前N行(N自己定数字)作为训练集,最后一列去掉,倒数第二列是label
注意这里的index是从0开始,to_index是end的下一个index

#0~N-1行,需要指定为0:N
train_data = obj.iloc[0:N, 1:-2]
train_label = obj.iloc[0:N, -2]

test_data = obj.iloc[N:, 1:-2]
test_label = obj.iloc[N:, -2]

由于数据量很少,很容易过拟合,减小迭代次数

model=CatBoostClassifier(iterations=2,depth=2,learning_rate=0.5,loss_function='Logloss', logging_level='Verbose')

训练

model.fit(train_data,train_label)

根据阈值来确定predict label

threshold=0.5

predict_label=np.zeros(prob.shape)
predict_label[prob>threshold]=1

计算准确率,SE,SP

cm = confusion_matrix(train_label,predict_label)
se = cm[0,0]/(cm[0,0]+cm[0,1])
print('SE : ', se)

sp = cm[1,1]/(cm[1,0]+cm[1,1])
print('SP : ', sp)

acc=(cm[0,0]+cm[1,1])/train_label.size
print ('Accuracy : ', acc)

计算AUC

#training data
fpr,tpr,threshold = roc_curve(train_label, prob,pos_label=1)
roc_auc = auc(fpr,tpr)

#test data
prob_test = model.predict_proba(test_data)
prob_test = prob_test[:,1]
fpr_test,tpr_test,threshold_test = roc_curve(test_label, prob_test,pos_label=1)
roc_auc_test = auc(fpr_test,tpr_test)

training data的ROC

plt.figure()
lw = 2
plt.figure(figsize=(6,6))
plt.plot(fpr, tpr, color='red',
         lw=lw, label='ACC = %0.2f, SE = %0.2f, SP = %0.2f, AUC = %0.2f' % (acc,se,sp,roc_auc)) ###假正率为横坐标,真正率为纵坐标做曲线
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve')
plt.legend(loc="lower right")
plt.grid()
plt.show()

test data的ROC

plt.figure()
lw = 2
plt.figure(figsize=(10,10))
plt.plot(fpr_test, tpr_test, color='red',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc_test) ###假正率为横坐标,真正率为纵坐标做曲线
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve')
plt.legend(loc="lower right")
plt.show()
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝羽飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值