The following is a simple implementation of GBDT using scikit-learn.
(1) Load the required Python modules
import pandas as pd
import numpy as np
import pydotplus
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import metrics
from io import StringIO  # sklearn.externals.six was removed from scikit-learn
from sklearn import tree
(2) Dataset information
The dataset is loaded as follows. train_modified.csv can be downloaded from: http://files.cnblogs.com/files/pinard/train_modified.zip
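train = pd.read_csv('train_modified.csv')
target = 'Disbursed'  # label column
IDcol = 'ID'          # row identifier, not a feature
# Use every column except the label and the ID as a feature
x_columns = [x for x in train.columns if x not in [target, IDcol]]
X = train[x_columns]
y = train[target]
print(train.head(10))  # preview the first 10 rows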
There are 49 features in total, listed below:
['Existing_EMI',
'Loan_Amount_Applied',
'Loan_Tenure_Applied',
'Monthly_Income',
'Var4',
'Var5',
'Age',
'EMI_Loan_Submitted_Missing',
'Interest_Rate_Missing',
'Loan_Amount_Submitted_Missing',
'Loan_Tenure_Submitted_Missing',
'Processing_Fee_Missing',
'Device_Type_0',
'Device_Type_1',
'Filled_Form_0',
'Filled_Form_1',
'Gender_0',
'Gender_1',
'Var1_0',
'Var1_1',
'Var1_2',
'Var1_3',
'Var1_4',
'Var1_5',
'Var1_6',
'Var1_7',
'Var1_8',
'Var1_9',
'Var1_10',
'Var1_11',
'Var1_12',
'Var1_13',
'Var1_14',
'Var1_15',
'Var1_16',
'Var1_17',
'Var1_18',
'Var2_0',
'Var2_1',
'Var2_2',
'Var2_3',
'Var2_4',
'Var2_5',
'Var2_6',
'Mobile_Verified_0',
'Mobile_Verified_1',
'Source_0',
'Source_1',
'Source_2']
(3) Train the model
gbdt_model = GradientBoostingClassifier(learning_rate=0.005,
                                        n_estimators=1200,
                                        max_depth=7,
                                        min_samples_leaf=60,
                                        min_samples_split=1200,
                                        max_features=9,
                                        subsample=0.7,
                                        random_state=10)
gbdt_model.fit(X,y)
y_pred = gbdt_model.predict(X)
y_pred_prob = gbdt_model.predict_proba(X)[:,1]
print("Accuracy : %.4g" % metrics.accuracy_score(y.values, y_pred))
print("AUC Score (Train): %f" % metrics.roc_auc_score(y, y_pred_prob))
# Accuracy : 0.984
# AUC Score (Train): 0.908232
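The model reaches an accuracy of 0.984 and an AUC of 0.9082. Both are measured on the training data itself, so they are likely to overstate how well the model generalizes, but the fit looks reasonable. Interested readers can try tuning the parameters.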
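Since both scores are computed on the same rows the model was trained on, a hold-out split gives a more honest estimate. Below is a minimal sketch, not the original author's procedure. It assumes synthetic data from `make_classification` as a stand-in for train_modified.csv, and the smaller `n_estimators` and larger `learning_rate` are illustrative choices to keep the run short, not the settings used above.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the real data (assumption: binary, imbalanced target)
X_demo, y_demo = make_classification(n_samples=2000, n_features=20,
                                     weights=[0.9, 0.1], random_state=10)

# Hold out 30% of the rows; stratify keeps the class ratio in both parts
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, stratify=y_demo, random_state=10)

demo_model = GradientBoostingClassifier(learning_rate=0.1, n_estimators=100,
                                        subsample=0.7, random_state=10)
demo_model.fit(X_tr, y_tr)

# Compare AUC on the rows seen during training with AUC on unseen rows
auc_train = roc_auc_score(y_tr, demo_model.predict_proba(X_tr)[:, 1])
auc_test = roc_auc_score(y_te, demo_model.predict_proba(X_te)[:, 1])
print("AUC (train): %.4f" % auc_train)
print("AUC (hold-out): %.4f" % auc_test)
```

With the real data, the same pattern would apply to `X` and `y` in place of `X_demo` and `y_demo`. A large gap between the two AUC values would suggest overfitting.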
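(4) Feature importance
The values below give the importance of each feature, in the same order as the columns of X. The larger the value, the more important the feature.
score_feature = gbdt_model.feature_importances_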
print(score_feature)
"""
[0.15929728 0.04617883 0.02608923 0.25709074 0.03488485 0.14528308
0.06911841 0.00267712 0.00290675 0.00819478 0.00553175 0.00428709
0.00861779 0.00927034 0.00676201 0.00652589 0.00667953 0.0060258
0.00059827 0.00570858 0.04263685 0.00957521 0. 0.00037965
0.00097049 0. 0.01065202 0.00118078 0.00832415 0.01314289
0.00166272 0.00267764 0. 0.01484987 0. 0.
0. 0. 0.00494884 0.00951563 0.00307255 0.00451455
0.00069745 0. 0.00436676 0.00482608 0.03472633 0.00907289
0.00647855]
"""
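The raw array is hard to read without the column names. A small sketch that pairs names with scores and sorts them is shown here. `rank_importances` is a hypothetical helper, not part of scikit-learn, and the sample data pairs the first seven feature names with the first seven printed values, assuming the array follows the order of the feature list above.

```python
def rank_importances(names, scores, top_n=5):
    """Return the top_n (name, score) pairs, highest score first."""
    pairs = sorted(zip(names, scores), key=lambda p: p[1], reverse=True)
    return pairs[:top_n]

# First seven features and their printed importances (order assumed to match)
names = ['Existing_EMI', 'Loan_Amount_Applied', 'Loan_Tenure_Applied',
         'Monthly_Income', 'Var4', 'Var5', 'Age']
scores = [0.15929728, 0.04617883, 0.02608923, 0.25709074,
          0.03488485, 0.14528308, 0.06911841]

top3 = rank_importances(names, scores, top_n=3)
for name, score in top3:
    print("%-20s %.4f" % (name, score))
```

With the fitted model, the same helper could be called as `rank_importances(x_columns, gbdt_model.feature_importances_)`. Under the ordering assumption, Monthly_Income, Existing_EMI and Var5 carry the three largest values in the full printed array.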
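(5) Plot the tree
dot_data = StringIO()
# estimators_[0, 0] is the regression tree fitted in the first boosting stage
tree.export_graphviz(gbdt_model.estimators_[0, 0],
                     out_file=dot_data,
                     node_ids=True,
                     filled=True,
                     rounded=True,
                     special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("gbdt.pdf")  # requires the Graphviz binaries to be installed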
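If Graphviz is not available, a plain-text rendering can be a lighter alternative. Below is a minimal sketch using `sklearn.tree.export_text`, which needs a reasonably recent scikit-learn. The synthetic data and the feature names `f0` to `f4` are made up for illustration and stand in for the real columns.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_text

# Synthetic stand-in data with hypothetical feature names
X_demo, y_demo = make_classification(n_samples=500, n_features=5,
                                     random_state=10)
demo_names = ['f%d' % i for i in range(X_demo.shape[1])]

demo_model = GradientBoostingClassifier(n_estimators=10, max_depth=2,
                                        random_state=10)
demo_model.fit(X_demo, y_demo)

# For binary classification estimators_ has shape (n_estimators, 1);
# [0, 0] selects the regression tree fitted in the first boosting stage
first_tree = demo_model.estimators_[0, 0]
tree_text = export_text(first_tree, feature_names=demo_names)
print(tree_text)
```

For the model trained above, passing `gbdt_model.estimators_[0, 0]` and `feature_names=x_columns` should produce the equivalent text dump, although a depth-7 tree will be long.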
The result is shown below: