CH5_决策树(ID3)及其spark实现

CH5_决策树(ID3)及其spark实现

1.决策树定义

决策树是一种基本的分类与回归方法,分类决策树模型是一种描述对实例进行分类的树状结构。决策树由结点(Node)和有向边(directed edge)组成。结点有两种类型:内部结点和叶结点。内部结点表示一个特征或属性,叶结点表示一个类。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QlIHY9Lt-1589348657852)(E3C88258B82E417C99458428C39F155B)]

决策树可以看成是一个if-then规则的集合,也可以表示给定特征条件下类的条件概率分布。

决策树的学习目标是根据给定的训练数据集构建一个决策树模型,使它能够对实例进行正确的分类。


2.决策树特征选择
2.1 信息增益

熵是表示随机变量不确定性的度量;设X是一个取有限个值的离散随机变量,其概率分布为:

P ( X = x i ) = p i , i = 1 , 2 , ⋅ ⋅ ⋅ n P(X=x_i)=p_i , i = 1,2,···n P(X=xi)=pi,i=1,2,n

随机变量的熵定义为:

H ( X ) = − ∑ p i l o g p i H(X) = -\sum{p_ilogp_i} H(X)=pilogpi

  • 信息增益

信息增益:特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差,即:

g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D,A) = H(D) - H(D|A) g(D,A)=H(D)H(DA)

根据信息增益准则选择特征的方法是:选择信息增益最大的特征

2.2 信息增益比

特 征 A 特 征 A 对 训 练 数 据 集 D 的 信 息 增 益 比 g R ( D , A ) 定 义 为 信 息 增 益 g ( D , A ) 与 训 练 集 D 关 于 特 征 A 的 值 的 熵 H A ( D ) 之 比 , 即 : g R ( D , A ) = g ( D , A ) H A ( D ) 其 中 , H A ( D ) = − ∑ ∣ D i ∣ ∣ D ∣ l o g 2 ∣ D i ∣ ∣ D ∣ , n 是 特 征 A 的 取 值 个 数 特征A特征A对训练数据集D的信息增益比g_R(D,A)定义为信息增益g(D,A) \\ 与训练集D关于特征A的值的熵H_A(D)之比,即:\\ g_R(D,A) = \frac{g(D,A)}{H_A(D)} \\ 其中,H_A(D) = -\sum \frac{|D_i|}{|D|}log_2\frac{|D_i|}{|D|},n是特征A的取值个数 AADgR(D,Ag(D,A)DAHA(D)gR(D,A)=HA(D)g(D,A)HA(D)=DDilog2DDi,nA

根据信息增益准则选择特征的方法是:选择信息增益比最大的特征

2.3 基尼系数

分 类 问 题 中 , 假 如 有 K 个 类 , 样 本 点 属 于 第 k 类 的 概 率 为 P k , 则 概 率 分 布 的 基 尼 系 数 定 义 为 : 分类问题中,假如有K个类,样本点属于第k类的概率为P_k,则概率分布的基尼系数定义为: KkPk,:
G i n i ( p ) = ∑ k = 1 K P k ( 1 − P k ) = 1 − ∑ k = 1 K P k 2 Gini(p) = \sum_{k=1}^{K} P_k(1-P_k) = 1-\sum_{k=1}^{K} P_k^2 Gini(p)=k=1KPk(1Pk)=1k=1KPk2
对于二分类问题,若样本点属于第一个类的概率是p,则概率分布的基尼系数为: G i n i ( p ) = 2 p ( 1 − p ) Gini(p) = 2p(1-p) Gini(p)=2p(1p)

对于给定的样本集合D,其基尼系数为:

G i n i ( p ) = 1 − ∑ ( ∣ C k ∣ ∣ D ∣ ) 2 这 里 , C k 是 D 中 属 于 第 k 类 的 样 本 子 集 , K 是 类 的 个 数 Gini(p) = 1-\sum (\frac{|C_k|}{|D|})^2 \\ 这里,C_k是D中属于第k类的样本子集,K是类的个数 Gini(p)=1(DCk)2CkDkK


3.几种常用的决策树算法
3.1 ID3算法:应用信息增益准则选择特征
3.2 C4.5算法:应用信息增益比准则选择特征
3.3 CART算法(分类与回归树):分类树用基尼系数选择最优特征

4. 决策树(ID3)模型
 package CH5_DecisionTree

import org.apache.spark.sql.functions.{col, count, log2, sum}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}

/**
  * Created by WZZC on 2019/12/6
  **/
case class DecisionTreeModel(data: DataFrame,
                             labelColName: String,
                             threshold: Double = 1e-2) {

  private val spark = data.sparkSession
  import spark.implicits._

  var node: DtreeNode = _
  var search: List[String] = _

  /**
    *  获取实例数最大的类Ck
    *
    * @param dataFrame
    * @return
    */
  def maxCountLabel(dataFrame: DataFrame) = {
    dataFrame
      .select(labelColName)
      .groupBy(labelColName)
      .agg(count(labelColName) as "ck")
      .collect()
      .map(row => (row.getString(0), row.getLong(1)))
      .maxBy(_._2)
      ._1
  }

  /**
    *最优特征选择(ID3)
    *
    * @param df dataframe
    * @param labelCol ck类S Colname
    * @param ftSchemas 特征集合
    * @return
    */
  def optimalFeatureSel(df: DataFrame,
                        labelCol: String,
                        ftSchemas: Array[String]) = {

    // 数据格式转换,行转列
    val ftsCount = df
      .flatMap(row => {
        val label = row.getAs[String](labelColName)
        (0 until row.length).map(i => {
          (label, ftSchemas(i), row.getString(i))
        })
      })
      .toDF("label", "ftsName", "ftsValue")
      .groupBy("label", "ftsName", "ftsValue")
      .agg(count("label") as "lcount")
      .repartition($"ftsName")
      .cache()

    //impiricalEntropy 经验熵
    val preProbdf = ftsCount
      .where($"ftsName" === labelCol)
      .cache()

    val dfcount: Double = preProbdf
      .agg(sum("lcount") as "lsum")
      .head()
      .getAs[Long]("lsum")
      .toDouble

    val impiricalEntropy: Double = preProbdf
      .withColumn("pi", $"lcount" / dfcount)
      .withColumn("hd", log2($"pi") * (-$"pi"))
      .agg(sum($"hd") as "hd")
      .collect()
      .head
      .getAs[Double]("hd")

    //经验条件熵
    val ftsValueCount: DataFrame = ftsCount
      .filter($"ftsName" =!= labelCol)
      .groupBy($"ftsName", $"ftsValue")
      .agg(sum("lcount") as "lsum")

    val cens: DataFrame = ftsCount
      .join(ftsValueCount, Seq("ftsName", "ftsValue"))
      .orderBy($"ftsName", $"label")
      .withColumn("cpi", $"lcount".cast(DoubleType) / $"lsum")
      .withColumn("pi", $"lsum" / dfcount)
      .withColumn("gda", -$"cpi" * log2($"cpi") * $"pi")
      .groupBy($"ftsName")
      .sum("gda")

    // 信息增益 ->最大信息增益
    val (ftsName, maxGda) = cens
      .withColumn("xxzy", -$"sum(gda)" + impiricalEntropy)
      .collect()
      .map(row => {
        val ftsName = row.getString(0)
        val maxgda = row.getDouble(2)
        (ftsName, maxgda)
      })
      .maxBy(_._2)

    val ftsLabels: Array[String] = ftsValueCount
      .where($"ftsName" === ftsName)
      .select($"ftsValue")
      .collect()
      .map(_.getString(0))

    ftsCount.unpersist()
    preProbdf.unpersist()

    (ftsName, maxGda, ftsLabels)

  }

  /**
    * 数据按照特征的值划分
    *
    * @param df dataframe
    * @param ftsName 划分的特征名称
    * @param ftsLabels
    * @return
    */
  def splitByFts(df: DataFrame, ftsName: String, ftsLabels: Array[String]) = {
    val column: Column = col(ftsName)
    ftsLabels.map(ftsvalue => {
      ftsvalue -> df.where(column === ftsvalue).drop(ftsName)
    })
  }

  def fit  = {

    var searchList: List[String] = Nil

    /**
      *
      * @param data  DataFrame
      * @param fNodeName 当前结点的特征名称
      * @return
      */
    def creatTree(data: DataFrame, fNodeName: String = null): DtreeNode = {

      data.persist()

      val datalabels = data.select(labelColName).distinct()
      // 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T
      val dataCount: Long = datalabels.count()
      if (dataCount == 1) {
        val ck: String = datalabels.head().getString(0)
        DtreeNode(fNodeName, ck, Nil)
      }

      // 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T
      // 数据特征名称
      val ftSchemas: Array[String] = data.columns
      if (ftSchemas.isEmpty) {
        val ck: String = maxCountLabel(data)
        DtreeNode(fNodeName, ck, Nil)
      }

      // 3,计算信息增益;判断信息增益是否小于阈值 ,小于则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T
      // 不小于则递归创建树
      val (ftsName, maxgda, ftsLabels) =
        optimalFeatureSel(data, labelColName, ftSchemas)

      searchList = ftsName +: searchList

      val dtreeNode = if (maxgda < threshold) {
        val ck: String = maxCountLabel(data)
        DtreeNode(fNodeName, ck, Nil)
      } else {
        val nodaDfs: Array[(String, DataFrame)] =
          splitByFts(data, ftsName, ftsLabels)
        val nodes: Seq[(String, DtreeNode)] = nodaDfs
          .map(tp => {
            val ftsValue: String = tp._1
            val splitedDf: DataFrame = tp._2
            data.unpersist()
            ftsValue -> creatTree(splitedDf, ftsName)
          })
          .toSeq

        DtreeNode(ftsName, "", nodes)

      }

      dtreeNode

    }

    node = creatTree(data)

    search = searchList.reverse.distinct

  }

  /**
    *
    * @param prediction
    * @param dtnode
    */
  def predict(prediction: Dataset[Row]) = {

    val predictRdd = prediction.rdd.map(row => {
      def finder(node: DtreeNode, flist: List[String]): String = {
        val ftsValue: String = row.getAs[String](flist.head)
        node.label match {
          case "" =>
            val nextNode: DtreeNode = node.nexts.find(_._1 == ftsValue).get._2
            finder(nextNode, flist.tail)
          case _ => node.label
        }

      }
      val res: String = finder(node, search)
      Row.merge(row, Row(res))
    })

    val predictDfSchema =
      prediction.schema.add(StructField("predict", StringType))

    spark.createDataFrame(predictRdd, predictDfSchema)

  }

}


/**
  *  决策树类
  **/
case class DtreeNode(
  ftsName: String, //feature name
  label: String,
  nexts: Seq[(String, DtreeNode)] = Nil  
) {

}

  1. 算法测试
package CH5_DecisionTree

import org.apache.spark.sql._

/**
 * Created by WZZC on 2019/8/23
 **/
object ID3Ruuner {

 def main(args: Array[String]): Unit = {
   val spark = SparkSession
     .builder()
     .appName(s"${this.getClass.getSimpleName}")
     .master("local[*]")
     .getOrCreate()

   import spark.implicits._

   val df: Dataset[Row] = spark.read
     .option("header", true)
     .csv("data/ID3.csv")

   val model: DecisionTreeModel = DecisionTreeModel(df, "label")

   model.fit

   model.predict(df).show()

   spark.stop()
 }

}

数据来源

决策树数据.png

决策树结果查看
DtreeNode(isHouse,,WrappedArray(
	(否,DtreeNode(isWork,,WrappedArray(
		(是,DtreeNode(isWork,是,List())), 
		(否,DtreeNode(isWork,否,List()))))), 
	(是,DtreeNode(isHouse,是,List()))))


+----+------+-------+------+-----+-------+
| age|isWork|isHouse|credit|label|predict|
+----+------+-------+------+-----+-------+
|青年|    否|     否|  一般|   否|     否|
|青年|    否|     否|    好|   否|     否|
|青年|    是|     否|    好|   是|     是|
|青年|    是|     是|  一般|   是|     是|
|青年|    否|     否|  一般|   否|     否|
|中年|    否|     否|  一般|   否|     否|
|中年|    否|     否|    好|   否|     否|
|中年|    是|     是|    好|   是|     是|
|中年|    否|     是|非常好|   是|     是|
|中年|    否|     是|非常好|   是|     是|
|老年|    否|     是|非常好|   是|     是|
|老年|    否|     是|    好|   是|     是|
|老年|    是|     否|    好|   是|     是|
|老年|    是|     否|非常好|   是|     是|
|老年|    否|     否|  一般|   否|     否|
+----+------+-------+------+-----+-------+


决策树.png

参考资料

《统计学习方法》

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值