gbdt原文_gbdt+lr代码

import numpy as np

np.random.seed(10)

import matplotlib.pyplot as plt

from sklearn.datasets import make_classification

from sklearn.linear_model import LogisticRegression

from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier,

GradientBoostingClassifier)

from sklearn.preprocessing import OneHotEncoder

from sklearn.model_selection import train_test_split

from sklearn.metrics import roc_curve

from sklearn.pipeline import make_pipeline

n_estimator = 10

X, y = make_classification(n_samples=80000)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

X_train, X_train_lr, y_train, y_train_lr = train_test_split(X_train,

y_train,

test_size=0.5)

rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)

rf_enc = OneHotEncoder()

rf_lm = LogisticRegression()

rf.fit(X_train, y_train)

rf_enc.fit(rf.apply(X_train))

rf_lm.fit(rf_enc.transform(rf.apply(X_train_lr)), y_train_lr)

y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1]

fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)

grd = GradientBoostingClassifier(n_estimators=n_estimator)

grd_enc = OneHotEncoder()

grd_lm = LogisticRegression()

grd.fit(X_train, y_train) #GBDT建模

grd_enc.fit(grd.apply(X_train)[:, :, 0])

grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

y_pred_grd_lm = grd_lm.predict_proba(

grd_enc.transform(grd.apply(X_test)[:, :, 0]))[:, 1]

fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm)

y_pred_grd = grd.predict_proba(X_test)[:, 1]

fpr_grd, tpr_grd, _ = roc_curve(y_test, y_pred_grd)

y_pred_rf = rf.predict_proba(X_test)[:, 1]

fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_rf)

plt.figure(2)

plt.xlim(0, 0.2)

plt.ylim(0.8, 1)

plt.plot([0, 1], [0, 1], ‘k--‘)

plt.plot(fpr_rf, tpr_rf, label=‘RF‘)

plt.plot(fpr_rf_lm, tpr_rf_lm, label=‘RF + LR‘)

plt.plot(fpr_grd, tpr_grd, label=‘GBT‘)

plt.plot(fpr_grd_lm, tpr_grd_lm, label=‘GBT + LR‘)

plt.xlabel(‘False positive rate‘)

plt.ylabel(‘True positive rate‘)

plt.title(‘ROC curve (zoomed in at top left)‘)

plt.legend(loc=‘best‘)

plt.show()

145cfe02e2917d0615b7f3cc6eae61fa.png

原文:http://blog.51cto.com/yixianwei/2156791

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值