spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(四)
spark源码分析之随机森林(Random Forest)(五)
Spark在mllib中实现了tree相关的算法,决策树DT(DecisionTree),随机森林RF(RandomForest),GBDT(Gradient Boosting Decision Tree),其基础都是RF,DT是RF一棵树时的情况,而GBDT则是循环构建DT,GBDT与DT的代码是非常简单明了的,本文将分成五部分分别对RF的源码进行分析,介绍spark在实现过程中使用的一些技巧。
1. 决策树与随机森林
首先对决策树和随机森林进行简单的回顾。
1.1. 决策树
在决策树的训练中,如上图所示,就是从根节点开始,不断的分裂,直到触发截止条件,在节点的分裂过程中要解决的问题其实就2个
- 分裂点:一般就是遍历所有特征的所有特征值,选取impurity最大的分成左右孩子节点,impurity的选取有信息熵(分类),最小均方差(回归)等方法
- 预测值:一般取当前最多的class(分类)或者取均值(回归)
1.2. 随机森林
随机森林就是构建多棵决策树投票,在构建多棵树过程中,引入随机性,一般体现在两个方面,一是每棵树使用的样本进行随机抽样,分为有放回和无放回抽样。二是对每棵树使用的特征集进行抽样,使用部分特征训练。
在训练过程中,如果单机内存能放下所有样本,可以用多线程同时训练多棵树,树之间的训练互不影响。
2. spark RF优化策略
spark在实现RF时,使用了一些优化技巧,提高训练效率。
2.1. 逐层训练
当样本量过大,单机无法容纳时,只能采用分布式的训练方法,数据是在集群中的多台机器存放,如果按照单机的方法,每棵树完全独立访问样本数据,则样本数据的访问次数为数的个数k*每棵树的节点数N,相当于深度遍历。在spark的实现中,因为数据存放在不同的机器上,频繁的访问数据效率非常低,因此采用广度遍历的方法,每次构造所有树的一层,例如如果要训练10棵树,第一次构造所有树的第一层根节点,第二次构造所有深度为2的节点,以此类推,这样访问数据的次数降为树的最大深度,大大减少了机器之间的通信,提高训练效率。
2.2. 样本抽样
当样本存在连续特征时,其可能的取值可能是无限的,存储其可能出现的值占用较大空间,因此spark对样本进行了抽样,抽样数量