CH5_决策树(ID3)及其spark实现
1.决策树定义
决策树是一种基本的分类与回归方法,分类决策树模型是一种描述对实例进行分类的树状结构。决策树由结点(Node)和有向边(directed edge)组成。结点有两种类型:内部结点和叶结点。内部结点表示一个特征或属性,叶结点表示一个类。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QlIHY9Lt-1589348657852)(E3C88258B82E417C99458428C39F155B)]
决策树可以看成是一个if-then规则的集合,也可以表示给定特征条件下类的条件概率分布。
决策树的学习目标是根据给定的训练数据集构建一个决策树模型,使它能够对实例进行正确的分类。
2.决策树特征选择
2.1 信息增益
- 熵
熵是表示随机变量不确定性的度量;设X是一个取有限个值的离散随机变量,其概率分布为:
P ( X = x i ) = p i , i = 1 , 2 , ⋅ ⋅ ⋅ n P(X=x_i)=p_i , i = 1,2,···n P(X=xi)=pi,i=1,2,⋅⋅⋅n
随机变量的熵定义为:
H ( X ) = − ∑ p i l o g p i H(X) = -\sum{p_ilogp_i} H(X)=−∑pilogpi
- 信息增益
信息增益:特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差,即:
g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D,A) = H(D) - H(D|A) g(D,A)=H(D)−H(D∣A)
根据信息增益准则选择特征的方法是:选择信息增益最大的特征
2.2 信息增益比
特 征 A 特 征 A 对 训 练 数 据 集 D 的 信 息 增 益 比 g R ( D , A ) 定 义 为 信 息 增 益 g ( D , A ) 与 训 练 集 D 关 于 特 征 A 的 值 的 熵 H A ( D ) 之 比 , 即 : g R ( D , A ) = g ( D , A ) H A ( D ) 其 中 , H A ( D ) = − ∑ ∣ D i ∣ ∣ D ∣ l o g 2 ∣ D i ∣ ∣ D ∣ , n 是 特 征 A 的 取 值 个 数 特征A特征A对训练数据集D的信息增益比g_R(D,A)定义为信息增益g(D,A) \\ 与训练集D关于特征A的值的熵H_A(D)之比,即:\\ g_R(D,A) = \frac{g(D,A)}{H_A(D)} \\ 其中,H_A(D) = -\sum \frac{|D_i|}{|D|}log_2\frac{|D_i|}{|D|},n是特征A的取值个数 特征A特征A对训练数据集D的信息增益比gR(D,A)定义为信息增益g(D,A)与训练集D关于特征A的值的熵HA(D)之比,即:gR(D,A)=HA(D)g(D,A)其中,HA(D)=−∑∣D∣∣Di∣log2∣D∣∣Di∣,n是特征A的取值个数
根据信息增益准则选择特征的方法是:选择信息增益比最大的特征
2.3 基尼系数
分
类
问
题
中
,
假
如
有
K
个
类
,
样
本
点
属
于
第
k
类
的
概
率
为
P
k
,
则
概
率
分
布
的
基
尼
系
数
定
义
为
:
分类问题中,假如有K个类,样本点属于第k类的概率为P_k,则概率分布的基尼系数定义为:
分类问题中,假如有K个类,样本点属于第k类的概率为Pk,则概率分布的基尼系数定义为:
G
i
n
i
(
p
)
=
∑
k
=
1
K
P
k
(
1
−
P
k
)
=
1
−
∑
k
=
1
K
P
k
2
Gini(p) = \sum_{k=1}^{K} P_k(1-P_k) = 1-\sum_{k=1}^{K} P_k^2
Gini(p)=k=1∑KPk(1−Pk)=1−k=1∑KPk2
对于二分类问题,若样本点属于第一个类的概率是p,则概率分布的基尼系数为:
G
i
n
i
(
p
)
=
2
p
(
1
−
p
)
Gini(p) = 2p(1-p)
Gini(p)=2p(1−p)
对于给定的样本集合D,其基尼系数为:
G i n i ( p ) = 1 − ∑ ( ∣ C k ∣ ∣ D ∣ ) 2 这 里 , C k 是 D 中 属 于 第 k 类 的 样 本 子 集 , K 是 类 的 个 数 Gini(p) = 1-\sum (\frac{|C_k|}{|D|})^2 \\ 这里,C_k是D中属于第k类的样本子集,K是类的个数 Gini(p)=1−∑(∣D∣∣Ck∣)2这里,Ck是D中属于第k类的样本子集,K是类的个数
3.几种常用的决策树算法
3.1 ID3算法:应用信息增益准则选择特征
3.2 C4.5算法:应用信息增益比准则选择特征
3.3 CART算法(分类与回归树):分类树用基尼系数选择最优特征
4. 决策树(ID3)模型
package CH5_DecisionTree
import org.apache.spark.sql.functions.{col, count, log2, sum}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
/**
* Created by WZZC on 2019/12/6
**/
case class DecisionTreeModel(data: DataFrame,
labelColName: String,
threshold: Double = 1e-2) {
private val spark = data.sparkSession
import spark.implicits._
var node: DtreeNode = _
var search: List[String] = _
/**
* 获取实例数最大的类Ck
*
* @param dataFrame
* @return
*/
def maxCountLabel(dataFrame: DataFrame) = {
dataFrame
.select(labelColName)
.groupBy(labelColName)
.agg(count(labelColName) as "ck")
.collect()
.map(row => (row.getString(0), row.getLong(1)))
.maxBy(_._2)
._1
}
/**
*最优特征选择(ID3)
*
* @param df dataframe
* @param labelCol ck类S Colname
* @param ftSchemas 特征集合
* @return
*/
def optimalFeatureSel(df: DataFrame,
labelCol: String,
ftSchemas: Array[String]) = {
// 数据格式转换,行转列
val ftsCount = df
.flatMap(row => {
val label = row.getAs[String](labelColName)
(0 until row.length).map(i => {
(label, ftSchemas(i), row.getString(i))
})
})
.toDF("label", "ftsName", "ftsValue")
.groupBy("label", "ftsName", "ftsValue")
.agg(count("label") as "lcount")
.repartition($"ftsName")
.cache()
//impiricalEntropy 经验熵
val preProbdf = ftsCount
.where($"ftsName" === labelCol)
.cache()
val dfcount: Double = preProbdf
.agg(sum("lcount") as "lsum")
.head()
.getAs[Long]("lsum")
.toDouble
val impiricalEntropy: Double = preProbdf
.withColumn("pi", $"lcount" / dfcount)
.withColumn("hd", log2($"pi") * (-$"pi"))
.agg(sum($"hd") as "hd")
.collect()
.head
.getAs[Double]("hd")
//经验条件熵
val ftsValueCount: DataFrame = ftsCount
.filter($"ftsName" =!= labelCol)
.groupBy($"ftsName", $"ftsValue")
.agg(sum("lcount") as "lsum")
val cens: DataFrame = ftsCount
.join(ftsValueCount, Seq("ftsName", "ftsValue"))
.orderBy($"ftsName", $"label")
.withColumn("cpi", $"lcount".cast(DoubleType) / $"lsum")
.withColumn("pi", $"lsum" / dfcount)
.withColumn("gda", -$"cpi" * log2($"cpi") * $"pi")
.groupBy($"ftsName")
.sum("gda")
// 信息增益 ->最大信息增益
val (ftsName, maxGda) = cens
.withColumn("xxzy", -$"sum(gda)" + impiricalEntropy)
.collect()
.map(row => {
val ftsName = row.getString(0)
val maxgda = row.getDouble(2)
(ftsName, maxgda)
})
.maxBy(_._2)
val ftsLabels: Array[String] = ftsValueCount
.where($"ftsName" === ftsName)
.select($"ftsValue")
.collect()
.map(_.getString(0))
ftsCount.unpersist()
preProbdf.unpersist()
(ftsName, maxGda, ftsLabels)
}
/**
* 数据按照特征的值划分
*
* @param df dataframe
* @param ftsName 划分的特征名称
* @param ftsLabels
* @return
*/
def splitByFts(df: DataFrame, ftsName: String, ftsLabels: Array[String]) = {
val column: Column = col(ftsName)
ftsLabels.map(ftsvalue => {
ftsvalue -> df.where(column === ftsvalue).drop(ftsName)
})
}
def fit = {
var searchList: List[String] = Nil
/**
*
* @param data DataFrame
* @param fNodeName 当前结点的特征名称
* @return
*/
def creatTree(data: DataFrame, fNodeName: String = null): DtreeNode = {
data.persist()
val datalabels = data.select(labelColName).distinct()
// 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T
val dataCount: Long = datalabels.count()
if (dataCount == 1) {
val ck: String = datalabels.head().getString(0)
DtreeNode(fNodeName, ck, Nil)
}
// 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T
// 数据特征名称
val ftSchemas: Array[String] = data.columns
if (ftSchemas.isEmpty) {
val ck: String = maxCountLabel(data)
DtreeNode(fNodeName, ck, Nil)
}
// 3,计算信息增益;判断信息增益是否小于阈值 ,小于则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T
// 不小于则递归创建树
val (ftsName, maxgda, ftsLabels) =
optimalFeatureSel(data, labelColName, ftSchemas)
searchList = ftsName +: searchList
val dtreeNode = if (maxgda < threshold) {
val ck: String = maxCountLabel(data)
DtreeNode(fNodeName, ck, Nil)
} else {
val nodaDfs: Array[(String, DataFrame)] =
splitByFts(data, ftsName, ftsLabels)
val nodes: Seq[(String, DtreeNode)] = nodaDfs
.map(tp => {
val ftsValue: String = tp._1
val splitedDf: DataFrame = tp._2
data.unpersist()
ftsValue -> creatTree(splitedDf, ftsName)
})
.toSeq
DtreeNode(ftsName, "", nodes)
}
dtreeNode
}
node = creatTree(data)
search = searchList.reverse.distinct
}
/**
*
* @param prediction
* @param dtnode
*/
def predict(prediction: Dataset[Row]) = {
val predictRdd = prediction.rdd.map(row => {
def finder(node: DtreeNode, flist: List[String]): String = {
val ftsValue: String = row.getAs[String](flist.head)
node.label match {
case "" =>
val nextNode: DtreeNode = node.nexts.find(_._1 == ftsValue).get._2
finder(nextNode, flist.tail)
case _ => node.label
}
}
val res: String = finder(node, search)
Row.merge(row, Row(res))
})
val predictDfSchema =
prediction.schema.add(StructField("predict", StringType))
spark.createDataFrame(predictRdd, predictDfSchema)
}
}
/**
* 决策树类
**/
case class DtreeNode(
ftsName: String, //feature name
label: String,
nexts: Seq[(String, DtreeNode)] = Nil
) {
}
- 算法测试
package CH5_DecisionTree
import org.apache.spark.sql._
/**
* Created by WZZC on 2019/8/23
**/
object ID3Ruuner {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName(s"${this.getClass.getSimpleName}")
.master("local[*]")
.getOrCreate()
import spark.implicits._
val df: Dataset[Row] = spark.read
.option("header", true)
.csv("data/ID3.csv")
val model: DecisionTreeModel = DecisionTreeModel(df, "label")
model.fit
model.predict(df).show()
spark.stop()
}
}
数据来源
决策树结果查看
DtreeNode(isHouse,,WrappedArray(
(否,DtreeNode(isWork,,WrappedArray(
(是,DtreeNode(isWork,是,List())),
(否,DtreeNode(isWork,否,List()))))),
(是,DtreeNode(isHouse,是,List()))))
+----+------+-------+------+-----+-------+
| age|isWork|isHouse|credit|label|predict|
+----+------+-------+------+-----+-------+
|青年| 否| 否| 一般| 否| 否|
|青年| 否| 否| 好| 否| 否|
|青年| 是| 否| 好| 是| 是|
|青年| 是| 是| 一般| 是| 是|
|青年| 否| 否| 一般| 否| 否|
|中年| 否| 否| 一般| 否| 否|
|中年| 否| 否| 好| 否| 否|
|中年| 是| 是| 好| 是| 是|
|中年| 否| 是|非常好| 是| 是|
|中年| 否| 是|非常好| 是| 是|
|老年| 否| 是|非常好| 是| 是|
|老年| 否| 是| 好| 是| 是|
|老年| 是| 否| 好| 是| 是|
|老年| 是| 否|非常好| 是| 是|
|老年| 否| 否| 一般| 否| 否|
+----+------+-------+------+-----+-------+
参考资料
《统计学习方法》