代码实现
1. 导包
[1]:import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split,cross_val_score,GridSearchCV
import graphviz
2. 加载数据
数据地址:
链接:https://pan.baidu.com/s/1WJFvu30qbF7vhzvvcOYPiw
提取码:jt9o
[2]:data = pd.read_csv("data.csv",index_col = "nameid")
data.head()
字段中文含义:
- name_id: 姓名
- profession: 职业,1-企业工作者,2-个体经营户,3-自由工作者,4-事业单位,5-体力劳动者
- education: 教育程度,1-博士及以上,2-硕士,3-本科,4-专科,5-高中及以下
- house_loan: 是否有房贷,1-有,0-没有
- car_loan:是否有车贷,1-有,0-没有
- married: 是否结婚,1-是,0-否
- child:是否有小孩,1-有,0-没有
- revenue:月收入
- approve:是否予以贷款,1-贷款,2-不贷款
3. 检查数据
[3]:data.info()
数据很干净,没有缺失值存在。共1000行数据
3. 切分测试集和训练集
先将特征和标签分开,然后将数据集中20%作为测试集,80%作为训练集。切分完后最好将每个数据集的索引恢复(决策树算法中不是必须)。
[4]:x = data.iloc[:,:-1]
y = data["approve"]
[5]:Xtrain,Xtest,Ytrain,Ytest = train_test_split(x,y,test_size = 0.2)
[6]:for i in [Xtrain,Xtest,Ytrain,Ytest]:
i.index = range(i.shape[0])
4. 建模
4.1 无参数建模
[7]:clf = DecisionTreeClassifier(random_state=10)
clf = clf.fit(Xtrain,Ytrain)
score = clf.score(Xtest,Ytest)
[8]:score
0.735
4.2 交叉验证
[9]:clf = DecisionTreeClassifier(random_state=10)
score = cross_val_score(clf,x,y,cv=10).mean()
score
0.7444790479047905
4.3 网格搜索
[10]:params = {
'criterion':['gini',"entropy"],
'max_depth':[*range(1,10,1)],
'min_samples_leaf':np.arange(3,10,1),
"min_impurity_decrease":[*np.linspace(0,0.5,50)]
}
[11]:clf = GridSearchCV(DecisionTreeClassifier(random_state=10),param_grid=params,cv=10)
clf = clf.fit(Xtrain,Ytrain)
[12]:clf.best_params_
{'criterion': 'gini',
'max_depth': 4,
'min_impurity_decrease': 0.0,
'min_samples_leaf': 6}
[13]:clf.best_score_
0.84625
4.5 模型可视化
[14]:clf = DecisionTreeClassifier(criterion="gini",max_depth=4,min_samples_leaf=6,min_impurity_decrease=0)
clf = clf.fit(Xtrain,Ytrain)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=data.columns[:-1],
class_names=["贷款","不贷款"],
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
[14]:graph = graphviz.Source(dot_data)
graph
参考资料:http://www.yearsmart.com/124.html