随机森林是由多棵树组成的分类或回归方法。主要思想来源于bagging算法,bagging技术思想主要是给定一弱分类器及训练集,让该学习算法训练多轮,每轮的训练集由原始训练集中有放回的随机抽取,大小一般跟原始训练集相当,这样依次训练多个弱分类器,最终的分类由这些弱分类器组合,对于分类问题一般采用多数投票法,对于回归问题一般采用简单平均法。随机森林在bagging的基础上,每个弱分类器都是决策树,决策树的生成过程中中,在属性的选择上增加了依一定概率选择属性,在这些属性中选择最佳属性及分割点,传统做法一般是全部属性中去选择最佳属性,这样随机森林有了样本选择的随机性,属性选择的随机性,这样一来增加了每个分类器的差异性、不稳定性及一定程度上避免每个分类器的过拟合(一般决策树有过拟合现象),由此组合分类器增加了最终的泛化能力。下面是代码的简单实现
/**
* 随机森林 回归问题
* @author ysh 1208706282
*
*/
public class randomforest {
list msamples;
list mcarts;
double mfeaturerate;
int mmaxdepth;
int mminleaf;
random mrandom;
/**
* 加载数据 回归树
* @param path
* @param regex
* @throws exception
*/
public void loaddata(string path,string regex) throws exception{
msamples = new arraylist();
bufferedreader reader = new bufferedreader(new filereader(path));
string line = null;
string splits[] = null;
sample sample = null;
while(null != (line=reader.readline())){
splits = line.split(regex);
sample = new sample();
sample.label = double.valueof(splits[0]);
sample.feature = new arraylist(splits.length-1);
for(int i=0;i
sample.feature.add(new double(splits[i+1]));
}
msamples.add(sample);
}
reader.close();
}
public void train(int iters){
mcarts = new arraylist(iters);
cart cart = null;
for(int iter=0;iter
cart = new cart();
cart.mfeaturerate = mfeaturerate;
cart.mmaxdepth = mmaxdepth;
cart.mminleaf = mminleaf;
cart.mrandom = mrandom;
list s = new arraylist(msamples.size());
for(int i=0;i
s.add(msamples.get(cart.mrandom.nextint(msamples.size())));
}
cart.setdata(s);
cart.train();
mcarts.add(cart);
system.out.println("iter: "+iter);
s = null;
}
}
/**
* 回归问题简单平均法 分类问题多数投票法
* @param sample
* @return
*/
public double classify(sample sample){
double val = 0;
for(cart cart:mcarts){
val += cart.classify(sample);
}
return val/mcarts.size();
}
/**
* @param args
* @throws exception
*/
public static void main(string[] args) throws exception {
// todo auto-generated method stub
randomforest forest = new randomforest();
forest.loaddata("f:/2016-contest/20161001/train_data_1.csv", ",");
forest.mfeaturerate = 0.8;
forest.mmaxdepth = 3;
forest.mminleaf = 1;
forest.mrandom = new random();
forest.mrandom.setseed(100);
forest.train(100);
list samples = cart.loadtestdata("f:/2016-contest/20161001/valid_data_1.csv", true, ",");
double sum = 0;
for(sample s:samples){
double val = forest.classify(s);
sum += (val-s.label)*(val-s.label);
system.out.println(val+" "+s.label);
}
system.out.println(sum/samples.size()+" "+sum);
system.out.println(system.currenttimemillis());
}
}
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持萬仟网。
希望与广大网友互动??
点此进行留言吧!