Spark-MLlib实例——决策树

57 篇文章 7 订阅
48 篇文章 4 订阅

Spark-MLlib实例——决策树

通俗来说,决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 女儿:多大年纪了?  
  2. 母亲:26。  
  3. 女儿:长的帅不帅?  
  4. 母亲:挺帅的。  
  5. 女儿:收入高不?  
  6. 母亲:不算很高,中等情况。  
  7. 女儿:是公务员不?  
  8. 母亲:是,在税务局上班呢。  
  9. 女儿:那好,我去见见。  




以上是决策的经典例子,用spark-mllib怎么实现训练与预测呢


1、首先准备测试数据集

训练数据集 Tree1

字段说明:

是否见面, 年龄  是否帅  收入(1 高 2 中等 0 少)  是否公务员

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 0,32 1 1 0  
  2. 0,25 1 2 0  
  3. 1,29 1 2 1  
  4. 1,24 1 1 0  
  5. 0,31 1 1 0  
  6. 1,35 1 2 1  
  7. 0,30 0 1 0  
  8. 0,31 1 1 0  
  9. 1,30 1 2 1  
  10. 1,21 1 1 0  
  11. 0,21 1 2 0  
  12. 1,21 1 2 1  
  13. 0,29 0 2 1  
  14. 0,29 1 0 1  
  15. 0,29 0 2 1  
  16. 1,30 1 1 0  


测试数据集 Tree2

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 0,32 1 2 0  
  2. 1,27 1 1 1  
  3. 1,29 1 1 0  
  4. 1,25 1 2 1  
  5. 0,23 0 2 1  

2、Spark-MLlib决策树应用代码

[java]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. import org.apache.log4j.{Level, Logger}  
  2. import org.apache.spark.mllib.feature.HashingTF  
  3. import org.apache.spark.mllib.linalg.Vectors  
  4. import org.apache.spark.mllib.regression.LabeledPoint  
  5. import org.apache.spark.mllib.tree.DecisionTree  
  6. import org.apache.spark.mllib.util.MLUtils  
  7. import org.apache.spark.{SparkConf, SparkContext}  
  8.   
  9. /** 
  10.   * 决策树分类 
  11.   */  
  12. object TreeDemo {  
  13.   
  14.   def main(args: Array[String]) {  
  15.   
  16.     val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")  
  17.     val sc = new SparkContext(conf)  
  18.     Logger.getRootLogger.setLevel(Level.WARN)  
  19.   
  20.     //训练数据  
  21.     val data1 = sc.textFile("data/Tree1.txt")  
  22.   
  23.     //测试数据  
  24.     val data2 = sc.textFile("data/Tree2.txt")  
  25.   
  26.   
  27.     //转换成向量  
  28.     val tree1 = data1.map { line =>  
  29.       val parts = line.split(',')  
  30.       LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))  
  31.     }  
  32.   
  33.     val tree2 = data2.map { line =>  
  34.       val parts = line.split(',')  
  35.       LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))  
  36.     }  
  37.   
  38.     //赋值  
  39.     val (trainingData, testData) = (tree1, tree2)  
  40.   
  41.     //分类  
  42.     val numClasses = 2  
  43.     val categoricalFeaturesInfo = Map[Int, Int]()  
  44.     val impurity = "gini"  
  45.   
  46.     //最大深度  
  47.     val maxDepth = 5  
  48.     //最大分支  
  49.     val maxBins = 32  
  50.   
  51.     //模型训练  
  52.     val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,  
  53.       impurity, maxDepth, maxBins)  
  54.   
  55.     //模型预测  
  56.     val labelAndPreds = testData.map { point =>  
  57.       val prediction = model.predict(point.features)  
  58.       (point.label, prediction)  
  59.     }  
  60.   
  61.     //测试值与真实值对比  
  62.     val print_predict = labelAndPreds.take(15)  
  63.     println("label" + "\t" + "prediction")  
  64.     for (i <- 0 to print_predict.length - 1) {  
  65.       println(print_predict(i)._1 + "\t" + print_predict(i)._2)  
  66.     }  
  67.   
  68.     //树的错误率  
  69.     val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()  
  70.     println("Test Error = " + testErr)  
  71.     //打印树的判断值  
  72.     println("Learned classification tree model:\n" + model.toDebugString)  
  73.   
  74.   }  
  75.   
  76. }  


3、测试结果:

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. label   prediction  
  2. 0.0 0.0  
  3. 1.0 1.0  
  4. 1.0 1.0  
  5. 1.0 1.0  
  6. 0.0 0.0  
  7. Test Error = 0.0  
  8. Learned classification tree model:  
可见真实值与预测值一致,Error为0


打印决策树的分支值,这里最大深度为 5 ,对应的树结构:

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. Learned classification tree model:  
  2. DecisionTreeModel classifier of depth 4 with 11 nodes  
  3.   If (feature 1 <= 0.0)  
  4.    Predict: 0.0  
  5.   Else (feature 1 > 0.0)  
  6.    If (feature 3 <= 0.0)  
  7.     If (feature 0 <= 30.0)  
  8.      If (feature 2 <= 1.0)  
  9.       Predict: 1.0  
  10.      Else (feature 2 > 1.0)  
  11.       Predict: 0.0  
  12.     Else (feature 0 > 30.0)  
  13.      Predict: 0.0  
  14.    Else (feature 3 > 0.0)  
  15.     If (feature 2 <= 0.0)  
  16.      Predict: 0.0  
  17.     Else (feature 2 > 0.0)  
  18.      Predict: 1.0  
可见预测出的分界值与真实一致,准确率与决策树算法,参数设置及训练样本的选择覆盖有关!
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值