机器学习---决策树

spark机器学习—决策树

--------仅用于个人学习知识整理和R语言/python代码整理


1.前言

项目用到了spark环境下的决策树,并且使用r和python的ml下的函数,在回来的时候学习了python sklearn包。
ml下画图及找到上级节点并不方便(如有方便的方法请告知我!),加上一些自己写的寻找上级节点的code


2.R部分代码实现

1. 建模及存储模型部分

这部分ml包的使用并不难,中规中矩的使用包就可以

###rawdata---建模数据,需要是 x1 x2 x3 x4 x5 y 格式
model_tree<-ml_decision_tree_classifier(train_data,y~x1+x2+x3+x4+x5,max_depth=7,main_instances_per_node=400,min_info_gain=0,impurity="gini",seed=123,threshold=0)
ml_save(model_tree,"save---path")
###只是我不知道怎么画树图 是用python画的 以及这里不解释选项的意思 请直接转r的帮助

###有一个需求是 找出上级节点们
###r的模型输出的底表是在/stages/1_decision_tree_*/data 的 parquet下 parquet的处理不赘述
leafs<-dt %>% filter(leftChild==-1 & rightChild==-1)
msg_raw<-data.frame()
dfs<-data.frame()

getparent<-function(node,current){
  left<-dt %>% filter(leftChild==node$id)
  right<-dt %>% filter(rightChild==node$id)
  
  if(left %>% count()>0){
    msgtt<-left %>% 
      mutate(flag="left",node=current) %>%
      select(node_flag,parent=id,featureIndex,leftCategoriesOrThreshold,numCategories)
    node<<-left
  } else {
    msgtt<-right %>% 
      mutate(flag="right",node=current) %>%
      select(node_flag,parent=id,featureIndex,leftCategoriesOrThreshold,numCategories)
    node<<-right
  }
  msg_raw<<-rbind(msg_raw,msgtt)
  return(msg_raw)
}

for (ii in leafs$id){
  node<-dt %>% filter(id==ii)
  
  while(node$id>0){
    xx<-getparent(node,ii)
  }
  dfs<-rbind(dfs,xx)
  msg_raw<-data.frame()
}

如果有更好的思路和方法,和我代码上的改进欢迎和我交流!


3.python部分代码实现

在此感谢本站博主写的一篇汇总,并且附上链接:https://blog.csdn.net/littlely_ll/article/details/78151964?locationNum=10&fps=1

主要解释了各个参数的含义,这里为了画树图,并且数据不大,没有用pyspark,而是使用了sklearn包中的决策树

from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.externals import joblib
import pydotplus

###dataset是用spark处理的
data_pivot_cus = spark.read.parquet("hdfs://spark-namenode1:9000/user/data_pivot/")

pd_data=data_pivot_cus.toPandas()

###变量列表
variable_list=data_pivot_cus.columns
variable_list.remove('md5')

#####交叉验证
# from sklearn.model_selection import GridSearchCV
# import numpy as np

# entropy_thresholds = np.linspace(0, 1, 100)
# gini_thresholds = np.linspace(0, 0.2, 100)
# #设置参数矩阵:
# param_grid = [{'criterion': ['entropy'], 'min_impurity_decrease': entropy_thresholds},
#               {'criterion': ['gini'], 'min_impurity_decrease': gini_thresholds},
#               {'max_depth': np.arange(4,8)},
#               {'min_samples_split': np.arange(100,300,50)}]
# clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=5)
# clf.fit(feature,target)
# print("best param:{0}\nbest score:{1}".format(clf.best_params_, clf.best_score_))

###决策树建模
clf=DecisionTreeClassifier(max_depth=7,criterion='entropy',random_state=123456,\
                            splitter='best',min_samples_split=100,min_samples_leaf=200)

feature=pd_data[variable_list]
target=pd_data['target']
model=clf.fit(feature,target)
joblib.dump(model,'DecisionTree_4.model')
model=joblib.load('DecisionTree_4.model')
dot_data = tree.export_graphviz(model, out_file=None, 
                         feature_names=variable_list,  
                         filled=True, rounded=True,  
                         special_characters=True)  
 ####服务器可能有中文乱码问题 需要安装中文字体包
graph = pydotplus.graph_from_dot_data(dot_data.replace('helvetica','"Microsoft YaHei"'))   

graph.write_png('decisiontree_9.png')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值