决策树算法,ID3,C4.5,CART原理,SparkMllib的鸢尾花实战

决策树算法原理

  • 什么是决策树

    • 现实生活中的树

      • 树根->树干->树枝->树叶
    • 数据结构中的树

      • 树根结点
      • 分支结点
      • 叶子结点
    • 机器学习中的树

      • 分支结点
      • 叶子结点
      • 数据集中的特征是分支节点、数据集中的类别标签列是叶子节点。
    • 决策树的关键步骤是分裂属性。

      所谓分裂属性就是在某个节点处按照某一特征属性的不同划分构造不同的分支,其目标是让各个分裂子集尽可能地“纯”。尽可能“纯”就是尽量让一个分裂子集中待分类项属于同一类别。

      而判断“纯”的方法不同引出了我们的ID3算法,C4.5算法以及CART算法

  • 基于规则建树

    • 规则是什么?
      • 业务专家给出的规则,这些规则当成是训练算法所需要的特征。
    • 定性简历决策树和定量的简历决策树
      • 为什么需要从定量角度来分析呢?
        • 答:这样会更精确的分析用户的特征信息,给出销售人员更准确的数据信息。
  • 基于模型的建树

    • 总结构建决策树的三要素

      • 特征的选择:信息熵、信息增益、信息增益率、基尼系数
      • 决策树的生成:ID3,C4.5,Cart树
      • 决策树的剪枝:先剪枝、后剪枝
    • 特征选择

      • 熵:物理学上度量能量分布不确定性的量。

      • 信息熵、香农熵:为了消除信息不确定性,代表随机变量的复杂度

        • 信息熵越大、信息的不确定性越大,信息的确定性越小,信息的纯度越低。
        • 信息熵越小、信息的不确定性越小,信息的确定性越大,信息的纯度越大。

        H X = − ∑ i = 1 n P ( x i ) l o g P ( x i ) HX=-\sum_{i=1}^{n}P_{(x_i)}log^{P_{(x_i)}} HX=i=1nP(xi)logP(xi)

      • 条件熵

        • 在某一个条件下,随机变量的复杂度

        H ( Y ∣ X ) = ∑ x ∈ X P ( x ) l o g P ( Y ∣ X = x ) H(Y|X)=\sum_{x \in X}P(x)log^{P(Y|X=x)} H(YX)=xXP(x)logP(YX=x)

      • 信息增益

        • 信息增益=信息熵-条件熵
        • 信息增益代表了在一个条件下,信息复杂度(不确定性)减少的程度
      • 信息增益率

        • 用信息增益率来选择属性,克服用信息增益选择时候偏向选择取值多属性不足。
          G a i n R a t e ( D , A ) = G a i n ( D , A ) H ( A ) GainRate(D,A) =\frac {Gain(D,A)} {H(A)}\\ GainRate(D,A)=H(A)Gain(D,A)

        • D为数据集、A为一个特征、其中H(A)为A的熵

      • 参考资料

  • ID3算法、C4.5算法

    • ID3

      • 在我们的ID3算法中,我们采取信息增益这个量来作为纯度的度量。我们选取使得信息增益最大的特征进行分裂!
      • 我们从上面求解信息增益的公式中,其实可以看出,信息增益准则其实是对可取值数目较多的属性有所偏好!
      • **因为每一个样本的编号都是不同的(由于编号独特唯一,条件熵为0了,每一个结点中只有一类,纯度非常高啊),也就是说,来了一个预测样本,你只要告诉我编号,其它特征就没有用了,**这样生成的决策树显然不具有泛化能力。
    • C4.5

      • 使用了信息增益率这个量来作为纯度的度量。我们选取使得信息增益率最大的特征进行分裂!

      • **我们一开始分析到,**信息增益准则其实是对可取值数目较多的属性有所偏好!(比如上面提到的编号,可能取值是实例个数,最多了,分的类别越多,分到每一个子结点,子结点的纯度也就越可能大,因为数量少了嘛,可能在一个类的可能性就最大)。

      • 但是在前面分析了,并不是很好,所以我们需要除以一个属性的固定值(这个属性的熵),这个值要求随着分成的类别数越大而越小。于是让它做了分母。这样可以避免信息增益的缺点。

      • 那么信息增益率就是完美无瑕的吗?

        当然不是,有了这个分母之后,我们可以看到增益率准则其实对可取类别数目较少的特征有所偏好!毕竟分母越小,整体越大。

      • 于是C4.5算法不直接选择增益率最大的候选划分属性,候选划分属性中找出信息增益高于平均水平的属性(这样保证了大部分好的的特征),再从中选择增益率最高的(又保证了不会出现编号特征这种极端的情况

    • 深入浅出理解决策树算法(二)-ID3算法与C4.5算法

  • Cart树算法

    • 简称:分类和回归树—和ID3、C4.5区别

      • 区别和联系:Cart树是二叉树,ID3和C4.5多棵决策树

      • Cart树在分类上使用的是集合Gini系数

        • 基尼系数如下:

        G i n i = ∑ i = 1 m P i ( 1 − P i ) = 1 − ∑ i = 1 m P i 2 Gini=\sum_{i=1}^{m}P_i(1-P_i)=1-\sum_{i=1}^{m}P_i^2 Gini=i=1mPi(1Pi)=1i=1mPi2

        • GINI指数
        • 对每个特征 A,对它的所有可能取值 a,将数据集分为 A=a D1,和 A!=a D2 两个子集,计算集合 D 的基尼指数:

        G I N I ( D , A ) = ∣ D 1 ∣ ∣ D ∣ G i n i ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ G i n i ( D 2 ) GINI(D,A)=\frac {|D_1|}{|D|}Gini(D_1)+\frac {|D_2|}{|D|}Gini(D_2) GINI(D,A)=DD1Gini(D1)+DD2Gini(D2)

      • 回归问题上MSE(mean square error-sum)
        1 m ∑ i = 1 m ( y i − y i ‾ ) 2 \frac 1 m \sum_{i=1}^{m}(y_i-\overline{y_i})^2 m1i=1m(yiyi)2

    • 回归树

      • 选择最优的切分点和切分变量
      • 用选定的切分点和切分变量对原来的数据区域进行划分
      • 递归调用算法生成多区域的二叉回归树
    • 是GBDT、XGBOOST算法的基础

    • 分类树

      • 1对每个特征 A,对它的所有可能取值 a,将数据集分为 A=a,和 A!=a 两个子集,计算集合 D 的基尼指数.

      G I N I ( D , A ) = ∣ D 1 ∣ ∣ D ∣ G i n i ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ G i n i ( D 2 ) GINI(D,A)=\frac {|D_1|}{|D|}Gini(D_1)+\frac {|D_2|}{|D|}Gini(D_2) GINI(D,A)=DD1Gini(D1)+DD2Gini(D2)

      • 2遍历所有的特征 A,计算其所有可能取值 a 的基尼指数,选择 D 的基尼指数最小值对应的特征及切分点作为最优的划分,将数据分为两个子集。
      • 3对上述两个子节点递归调用步骤(1)(2), 直到满足停止条件。
      • 4生成 CART 决策树。
  • 决策树、ID3、C4.5、Cart回归树、Cart分类树的剪枝问题还没有分析,每天分析下。

  • SparkMllib完成建模分析实践

    • 鸢尾花iris实战,使用rdd方式

      import org.apache.spark.mllib.linalg.Vectors
      import org.apache.spark.mllib.regression
      import org.apache.spark.mllib.regression.LabeledPoint
      import org.apache.spark.mllib.tree.DecisionTree
      import org.apache.spark.mllib.tree.model.DecisionTreeModel
      import org.apache.spark.rdd.RDD
      import org.apache.spark.{SparkConf, SparkContext}
      
      object SparkMllibIris1 {
        def main(args: Array[String]): Unit = {
          // 1.准本环境
          val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkMllibIris1Rdd")
          val sc = new SparkContext(conf)
          // 2.读取数据
          val path = "iris.csv"
          val rdd: RDD[String] = sc.textFile(path)
          //    rdd.foreach(println)
          //    6.2,3.4,5.4,2.3,Iris-virginica
          //    5.9,3.0,5.1,1.8,Iris-virginica
          // 3,特征工程
          // 3-1得到LabelPoint rdd中很多好用的API都没有,需要使用传统的方式进行特征提取,转换,选择
          var rddLp: RDD[LabeledPoint] = rdd.map(
            x => {
              val strings: Array[String] = x.split(",")
              regression.LabeledPoint(
                strings(4) match {
                  case "Iris-setosa" => 0.0
                  case "Iris-versicolor" => 1.0
                  case "Iris-virginica" => 2.0
                }
              ,
                Vectors.dense(
                  strings(0).toDouble,
                  strings(1).toDouble,
                  strings(2).toDouble,
                  strings(3).toDouble))
            }
          )
      //    rddLp.foreach(println)
      //    (1.0,[6.0,2.9,4.5,1.5])
      //    (0.0,[5.1,3.5,1.4,0.2])
          // 4. 分割数据集为训练集和测试集
          val Array(trainData,testData): Array[RDD[LabeledPoint]] = rddLp.randomSplit(Array(0.8,0.2))
          // 5. 构建模型
          val decisonModel: DecisionTreeModel = DecisionTree.trainClassifier(trainData,3, Map[Int, Int](),"gini",8,16)
          // 6. 得到测试集预测的结果,跟原有的标签共同构成一个元组,方便后面进行相应的计算
          // 而DataFrame中有相应的函数,可以帮助我们进行校验,RDD没有这方面的待遇,需要自己写相应的方法
          val result: RDD[(Double, Double)] = testData.map(
            x=> {
              val pre: Double = decisonModel.predict(x.features)
              (x.label,pre)
            }
          )
          val acc: Double = result.filter(x=>x._1==x._2).count().toDouble /result.count()
          println(acc)
          println("error", (1-acc))
      //    0.9642857142857143
      //    (error,0.0357142857142857)
        }
      }
      
    • 鸢尾花iris实战-DataFrame方式实现

      import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
      import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
      import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorAssembler}
      import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
      
      object SparkMlIris2 {
        def main(args: Array[String]): Unit = {
          //    * 1-准备环境
          val sparkSession: SparkSession = SparkSession.builder().master("local[*]").appName("SparkMllibIris2").getOrCreate()
          //    * 2-准备数据
          // 2-1 通过CSV的方式来读取数据,官网有读取的方式 http://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
          var path = "irisHeader.csv"
          // 注意要添加 .option("inferSchema", "true"),否则df schema 都是String类型的
          val df: DataFrame = sparkSession.read.format("csv").option("inferSchema", "true").option("header","true").option("sep",",").load(path)
      //    df.printSchema()
      //    root
      //    |-- sepal_length: double (nullable = true)
      //    |-- sepal_width: double (nullable = true)
      //    |-- petal_length: double (nullable = true)
      //    |-- petal_width: double (nullable = true)
      //    |-- class: string (nullable = true)
          //df.show(false)
          //    +------------+-----------+------------+-----------+-----------+
      //    |sepal_length|sepal_width|petal_length|petal_width|class      |
      //    +------------+-----------+------------+-----------+-----------+
      //    |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
          //* 4-特征工程
          //4-1将4个特征整合为一个特征向量
          val assembler: VectorAssembler = new VectorAssembler().setInputCols(Array("sepal_length","sepal_width","petal_length","petal_width")).setOutputCol("features")
          val assmblerDf: DataFrame = assembler.transform(df)
          assmblerDf.show(false)
          //4-2将类别型class转变为数值型
          val stringIndex: StringIndexer = new StringIndexer().setInputCol("class").setOutputCol("label")
          val stingIndexModel: StringIndexerModel = stringIndex.fit(assmblerDf)
          val indexDf: DataFrame = stingIndexModel.transform(assmblerDf)
      //    indexDf.show(false)
      //    +------------+-----------+------------+-----------+-----------+-----------------+-----+
      //    |sepal_length|sepal_width|petal_length|petal_width|class      |features         |label|
      //      +------------+-----------+------------+-----------+-----------+-----------------+-----+
      //    |5.1         |3.5        |1.4         |0.2        |Iris-setosa|[5.1,3.5,1.4,0.2]|0.0  |
      //      |4.9         |3.0        |1.4         |0.2        |Iris-setosa|[4.9,3.0,1.4,0.2]|0.0  |
          //4-3将数据切分成两部分,分别为训练数据集和测试数据集
          val Array(trainData,testData): Array[Dataset[Row]] = indexDf.randomSplit(Array(0.8,0.2))
          //    * 5-准备计算法,设置特征列和标签列
          val classifier: DecisionTreeClassifier = new DecisionTreeClassifier().setFeaturesCol("features").setMaxBins(16).setImpurity("gini").setSeed(10)
          val dtcModel: DecisionTreeClassificationModel = classifier.fit(trainData)
          //    * 6-完成建模分析
          val trainPre: DataFrame = dtcModel.transform(trainData)
          //    * 7-预测分析
          val testPre: DataFrame = dtcModel.transform(testData)
          //    * 8-模型的校验或保存
          //val savePath = "E:\\ml\\workspace\\SparkMllibBase\\sparkmllib_part2\\DescitionTree\\model"
          //dtcModel.save(savePath)
      //    trainPre.show(false)
      //    +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
      //    |sepal_length|sepal_width|petal_length|petal_width|class          |features         |label|rawPrediction |probability  |prediction|
      //      +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
      //    |4.3         |3.0        |1.1         |0.1        |Iris-setosa    |[4.3,3.0,1.1,0.1]|0.0  |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0       |
      //      |4.4         |2.9        |1.4         |0.2        |Iris-setosa    |[4.4,2.9,1.4,0.2]|0.0  |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0       |
      //    testPre.show(false)
      //    +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
      //    |sepal_length|sepal_width|petal_length|petal_width|class          |features         |label|rawPrediction |probability  |prediction|
      //      +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
      //    |4.6         |3.2        |1.4         |0.2        |Iris-setosa    |[4.6,3.2,1.4,0.2]|0.0  |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0       |
      //      |4.8         |3.4        |1.9         |0.2        |Iris-setosa    |[4.8,3.4,1.9,0.2]|0.0  |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0       |
      //      |5.0         |2.0        |3.5         |1.0        |Iris-versicolor|[5.0,2.0,3.5,1.0]|1.0  |[0.0,33.0,0.0]|[0.0,1.0,0.0]|1.0       |
          val acc: Double = new MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(testPre)
          println("acc is ", acc)
          println("err is", (1-acc))
          // 9-将测试集预测的索引类别标签转变回字符串类型的
          val indexToString: IndexToString = new IndexToString().setInputCol("prediction").setOutputCol("preStringLabel").setLabels(stingIndexModel.labels)
          val result: DataFrame = indexToString.transform(testPre)
      //    result.show(false)
      //    +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------------------------------------+----------+---------------+
      //    |sepal_length|sepal_width|petal_length|petal_width|class          |features         |label|rawPrediction |probability                                |prediction|preStringLabel |
      //      +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------------------------------------+----------+---------------+
      //    |4.6         |3.6        |1.0         |0.2        |Iris-setosa    |[4.6,3.6,1.0,0.2]|0.0  |[38.0,0.0,0.0]|[1.0,0.0,0.0]                              |0.0       |Iris-setosa    |
      //      |4.8         |3.4        |1.6         |0.2        |Iris-setosa    |[4.8,3.4,1.6,0.2]|0.0  |[38.0,0.0,0.0]|[1.0,0.0,0.0]                              |0.0       |Iris-setosa    |
        }
      }
      
      
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值