【机器学习系列】从导入数据到决策树可视化:一步步教你构建优化的机器学习模型

目录

一、导入数据

二、进行独热编码 

三、通过网格搜索进行参数调优,选择最优参数

四、利用训练好的参数建立决策树模型并进行交叉验证

五、将决策树可视化


一、导入数据

import pandas
data=pandas.read_csv('决策树.csv',
                      engine='python',encoding='utf8')

二、进行独热编码 

#需要进行独热处理的列
oneHotColumns = ['性别','父母鼓励']
from sklearn.preprocessing import OneHotEncoder
#新建独热编码器
oneHotEncoder = OneHotEncoder(drop='first')
#训练独热编码器,得到转换规则
oneHotEncoder.fit(data[oneHotColumns])
#转换数据
oneHotData = oneHotEncoder.transform(data[oneHotColumns])
from scipy.sparse import hstack
#将独热编码所得的数据,和父母收入、IQ两列合并在一起
x=hstack([oneHotData, data.父母收入.values.reshape(-1,1),data.IQ.values.reshape(-1,1)])
y=data['升学计划']

三、通过网格搜索进行参数调优,选择最优参数

from sklearn.model_selection import GridSearchCV

#全部训练全部测试,会导致过拟合
'''
    max_depth=None, 
    max_leaf_nodes=None, 
'''
dtModel = DecisionTreeClassifier()
dtModel.fit(x, y)
dtModel.score(x, y)

dtModel = DecisionTreeClassifier()

#网格搜索,寻找最优参数
paramGrid = dict(
    max_depth=[1, 2, 3, 4, 5],
    max_leaf_nodes=[3, 5, 6, 7, 8],
)
dtModel = DecisionTreeClassifier()
grid = GridSearchCV(
    dtModel, paramGrid, cv=10,
    return_train_score=True
)
grid = grid.fit(x, y)

print('最好的得分是: %f' % grid.best_score_)
print('最好的参数是:')
for key in grid.best_params_.keys():
    print('%s=%s'%(key, grid.best_params_[key]))

四、利用训练好的参数建立决策树模型并进行交叉验证

dtModel = DecisionTreeClassifier(
    criterion='gini', 
    max_depth=4, 
    max_leaf_nodes=7
)

cross_val_score(dtModel, x, y, cv=10).mean()

#训练决策树模型
dtModel = DecisionTreeClassifier(
    max_depth=4, 
    max_leaf_nodes=7
)
dtModel.fit(x, y)

五、将决策树可视化

#将决策树模型导出为 dot 文件
from sklearn.tree import export_graphviz
with open('data.dot', 'w') as f:
    f = export_graphviz(dtModel, out_file=f)
#绘图命令
#dot -Tpng data.dot -o tree.png

#导入pydot模块
import pydot_ng as pydot
#导入内存IO模块

#from sklearn.externals.six import StringIO
from six import StringIO

#把dot文件,写入StringIO中
dot_data = StringIO()
'''
    class_names: dtModel.classes_
    feature_names: oneHotEncoder.get_feature_names()    
'''
export_graphviz(
    dtModel, 
    out_file=dot_data,
    class_names=["不计划", "计划"],
    feature_names=[
        '男性', '父母鼓励', '父母收入', '智商'
    ],
    filled=True, rounded=True, 
    special_characters=True
) 
#从字符串中读入dot,生成graph对象
graph = pydot.graph_from_dot_data(
    dot_data.getvalue()
) 
#设置所有的节点的字体属性为 Microsoft YaHei
graph.get_node("node")[0].set_fontname(
    "Microsoft YaHei"
)
#将图形保存到 opt_tree.png 文件中
graph.write_png(
    'opt_tree.png'
)

r = data.pivot_table(
    index='父母鼓励',
    columns='升学计划',
    values='学生ID', 
    aggfunc='count'
)

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值