了解了一些决策树的构建算法后,现在学习下随机森林。还是先上一些基本概念:
随机森林是一种比较新的机器学习模型。顾名思义,是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。在得到森林之后,当有一个新的输入样本进入的时候,就让森林中的每一棵决策树分别进行一下判断,看看这个样本应该属于哪一类(对于分类算法),然后看看哪一类被选择最多,就预测这个样本为那一类,即选举投票。
优点:
a. 在数据集上表现良好,两个随机性的引入,使得随机森林不容易陷入过拟合
b. 在当前的很多数据集上,相对其他算法有着很大的优势,两个随机性的引入,使得随机森林具有很好的抗噪声能力
c. 它能够处理很高维度(feature很多)的数据,并且不用做特征选择,对数据集的适应能力强:既能处理离散型数据,也能处理连续型数据,数据集无需规范化
d. 可生成一个Proximities=(pij)矩阵,用于度量样本之间的相似性: pij=aij/N, aij表示样本i和j出现在随机森林中同一个叶子结点的次数,N随机森林中树的颗数
e. 在创建随机森林的时候,对generlization error使用的是无偏估计
f. 训练速度快,可以得到变量重要性排序(两种:基于OOB误分率的增加量和基于分裂时的GINI下降量
g. 在训练过程中,能够检测到feature间的互相影响
h. 容易做成并行化方法
i. 实现比较简单
说白了,随机森林就是由许多个决策树构成,决策树使用什么算法取决于你。每个决策树构建需要的数据集是总数据集的随机抽取。同时每个抽取出来的数据集也不一定是包含所有特征属性,其含有的特征属性也是随机从总特征属性中随机抽取。随机森林等到所有决策树构建完成后,对样本数据集进行测试分类。最终的结果可以通过简单的投票选择获得,也可以通过复杂的权重计算获得等等。
下面是随机森林Java的简单实现
public class ForestBuilder extends BuilderAbstractImpl {
/** 决策树数量*/
private int treeNum = 0;
/** 随机属性数量*/
private int attributeNum = 0;
/** 构建决策树Builder*/
private Builder builder = null;
public ForestBuilder(int treeNum, Builder builder, int attributeNum) {
this.treeNum = treeNum;
this.builder = builder;
this.attributeNum = attributeNum;
}
@Override
public Object build(Data data) {
ExecutorService pools = Executors.newFixedThreadPool(
Runtime.getRuntime().availableProcessors());
List<Future<TreeNode>> futures = new ArrayList<Future<TreeNode>>();
for (int i = 0; i < treeNum; i++) {
//线程里面去构建决策树
DecisionCallable callable = new DecisionCallable(data, builder, attributeNum);
futures.add(pools.submit(callable));
}
System.out.println("futures size: " + futures.size());
//等待线程创建完决策树
List<TreeNode> results = new ArrayList<TreeNode>();
handleFuture(futures, results);
int futureLen = futures.size();
int resultsLen = results.size();
while (resultsLen < futureLen) {
handleFuture(futures, results);
resultsLen = results.size();
}
pools.shutdown();
return results;
}
private void handleFuture(List<Future<TreeNode>> futures, List<TreeNode> results) {
Iterator<Future<TreeNode>> iterator = futures.iterator();
while (iterator.hasNext()) {
Future<TreeNode> future = iterator.next();
if (future.isDone()) {
try {
results.add(future.get());
iterator.remove();
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
}
class DecisionCallable implements Callable<TreeNode> {
private Data data = null;
private int attributeNum = 0;
private Builder builder = null;
public DecisionCallable(Data data, Builder builder, int attributeNum) {
this.data = data;
this.builder = builder;
this.attributeNum = attributeNum;
}
@Override
public TreeNode call() throws Exception {
Data randomData = DataLoader.loadRandom(data, attributeNum);
Object object = builder.build(randomData);
return null != object ? (TreeNode) object : null;
}
}
public class ForestNode extends Node {
private static final long serialVersionUID = 1L;
private List<TreeNode> treeNodes = null;
public ForestNode(List<TreeNode> treeNodes) {
this.treeNodes = treeNodes;
}
@Override
public Object classify(Data data) {
List<Object[]> results = new ArrayList<Object[]>();
for (TreeNode treeNode : treeNodes) {
Object result = treeNode.classify(data);
if (null != result) {
results.add((Object[]) treeNode.classify(data));
}
}
return DataHandler.vote(results);
}
@Override
public Object classify(Instance... instances) {
List<Object[]> results = new ArrayList<Object[]>();
for (TreeNode treeNode : treeNodes) {
Object result = treeNode.classify(instances);
if (null != result) {
results.add((Object[]) treeNode.classify(instances));
}
}
//投票选择
return DataHandler.vote(results);
}
}