近日听了七月天空周博的课,现在对随机森林做一个简单的实现。
随机森林(random forest)是一种利用多个分类树对数据进行判别与分类的方法,它在对数据进行分类的同时,还可以给出各个变量(基因)的重要性评分,评估各个变量在分类中所起的作用。
随机森林是一个最近比较火的算法,它有很多的优点:
a. 在数据集上表现良好,两个随机性的引入,使得随机森林不容易陷入过拟合
b. 在当前的很多数据集上,相对其他算法有着很大的优势,两个随机性的引入,使得随机森林具有很好的抗噪声能力
c. 它能够处理很高维度(feature很多)的数据,并且不用做特征选择,对数据集的适应能力强:既能处理离散型数据,也能处理连续型数据,数据集无需规范化
d. 可生成一个Proximities=(pij)矩阵,用于度量样本之间的相似性: pij=aij/N, aij表示样本i和j出现在随机森林中同一个叶子结点的次数,N随机森林中树的颗数
e. 在创建随机森林的时候,对generlization error使用的是无偏估计
f. 训练速度快,可以得到变量重要性排序(两种:基于OOB误分率的增加量和基于分裂时的GINI下降量)
g. 在训练过程中,能够检测到feature间的互相影响
h. 容易做成并行化方法
i. 实现比较简单
/*
* RF.h
*
* Created on: Nov 8, 2015
* Author: shenjiyi
*/
#ifndef RF_H_
#define RF_H_
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <vector>
using namespace std;
#define MAX_RANDOM 10000
#define MAXN 10000
// A single training sample: feature vector `x` and integer class label `y`.
struct Data {
    std::vector<int> x;
    int y;
};
typedef std::vector<Data> Datas;

// Outcome of one candidate split: the threshold `v` and the two partitions
// (samples with feature value > v go right, <= v go left).
struct SubTreeType {
    double v;
    Datas right, left;
};

// One node of a complete binary decision tree stored in array layout
// (children of node i live at 2i+1 and 2i+2).
struct TreeNodeType {
    double splitValue;   // split threshold; -1 means "not split" (leaf or empty node)
    int splitFeature;    // FIX: feature *index* used to subscript Data::x — was
                         // declared double, forcing implicit float->int narrowing
                         // at every use; it is only ever assigned ints (or -1)
    int winClass;        // majority class at a leaf; -1 on internal/empty nodes
    Datas samp;          // samples routed to this node while the tree is built
};
typedef std::vector<TreeNodeType> Tree;
typedef std::vector<Tree> Forest;

// Training hyper-parameters for the forest.
struct Option {
    int treeNumber;   // number of trees
    int bagNumber;    // bootstrap sample size drawn per tree
    int depth;        // maximum tree depth
    int bestSelect;   // random split candidates tried per node
};
class RF {
private:
Forest forest;
double calGini(Datas &data, int &winClass) {
map<int, int> mp; mp.clear();
for (auto d : data) {
mp[d.y]++;
}
double sum = 0;
int winNumber = 0;
for (auto k : mp) {
// cout << k.first << "\t" << k.second << endl;
int key = k.first;
sum = sum + mp[key] * mp[key] /(data.size() + 0.0)/ (data.size() + 0.0);
if (mp[key] > winNumber) {
winClass = key;
winNumber = mp[key];
}
}
return 1 - sum;
}
double splitGini(Datas &left, Datas &right) {
int totalNumber = left.size() + right.size();
int idx = -1;
double sum = (left.size() + 0.0 / totalNumber) * calGini(left, idx)
+ (right.size() + 0.0 / totalNumber) * calGini(right, idx);
return sum;
}
double randomf(double a,double b){
return (rand()%(int)((b-a)*MAX_RANDOM))
/(double)MAX_RANDOM+a;
}
double randomi(double a, double b) {
return floor(randomf(a, b));
}
double randoms() {
return (rand()%(int)(MAX_RANDOM))
/(double)MAX_RANDOM;
}
void randomSplit(Datas &data, int feature, SubTreeType& subTree) {
int a = randomi(0, data.size());
int b = randomi(0, data.size());
while (a == b) {b = randomi(0, data.size());}
double s = randoms();
double splitValue = s * data[a].x[feature]
+ (1 - s) * data[b].x[feature];
// cout << "a = "<<a<<" b = "<<b <<" s= " << s << " splitvalue= " << splitValue << endl;
subTree.v = splitValue;
for (int i = 0; i < data.size(); ++i) {
if (data[i].x[feature] > splitValue) {
subTree.right.push_back(data[i]);
} else {
subTree.left.push_back(data[i]);
}
}
}
void createSingleTree(Datas &data, int depth, int bestSelect, Tree& singleTree) {
int featureNumber = data[0].x.size();
int allNumber = pow(2, depth + 1) - 1;
int nodeNumber = pow(2, depth) - 1;
singleTree.clear();
for (int i = 0; i < allNumber; ++i) {
TreeNodeType tmp;
singleTree.push_back(tmp);
singleTree[i].splitValue = -1;
singleTree[i].winClass = -1;
singleTree[i].splitFeature = -1;
singleTree[i].samp.clear();
}
for (int i = 0; i < data.size(); ++i) {
singleTree[0].samp.push_back(data[i]);
}
for (int i = 0; i < nodeNumber; i++) {
Datas &samples = singleTree[i].samp;
if (samples.size() == 0 || samples.size() == 1) {
continue;
}
int feature = randomi(0, featureNumber);
int idx;
double bestGini = calGini(samples, idx);
// cout << "bestSelect" << bestSelect << "bestGini " << bestGini << endl;
SubTreeType *bestTree = NULL;
SubTreeType subTree;
for (int j = 0; j < bestSelect && bestGini > 0; ++j) {
subTree.left.clear(); subTree.right.clear();
subTree.v = -1;
randomSplit(samples, feature, subTree);
double newGini = splitGini(subTree.left, subTree.right);
if (newGini < bestGini) {
bestGini = newGini;
bestTree = &subTree;
}
}
if (bestTree != NULL) {
singleTree[i].splitValue = bestTree->v;
singleTree[i].splitFeature = feature;
singleTree[i * 2 + 1].samp = move(bestTree->left);
singleTree[i * 2 + 2].samp = move(bestTree->right);
}
}
// cout << "sss" << endl;
for (int i = 0; i < allNumber; ++i) {
if (singleTree[i].splitValue == -1 && singleTree[i].samp.size() > 0) {
int idx = -1;
calGini(singleTree[i].samp, idx);
singleTree[i].winClass = idx;
}
}
}
Datas bagging(Datas &data, int bagNumber) {
Datas bag; bag.clear();
for (int i = 0; i < bagNumber; ++i) {
int n = randomi(0, data.size());
bag.push_back(data[n]);
}
return bag;
}
void createForest(Datas &data, Option option) {
int treeNumber = option.treeNumber;
int bagNumber = option.bagNumber;
int depth = option.depth;
int bestSelect = option.bestSelect;
forest.clear();
Tree tmp;
for (int i = 0; i < treeNumber; ++i) {
Datas subData = bagging(data, bagNumber);
createSingleTree(subData, depth, bestSelect, tmp);
forest.push_back(tmp);
}
}
int predWithTree(Tree &tree, vector<int> &x) {
// cout <<"tree size " << tree.size() << endl;
for (int i = 0;;) {
if (i >= tree.size()) {
return -1;
}
if (tree[i].winClass != -1) {
// cout << "tree winclass " << tree[i].winClass << endl;
return tree[i].winClass;
}
if (x[tree[i].splitFeature] < tree[i].splitValue) {
i = 2 * i + 1;
} else {
i = 2 * i + 2;
}
}
return -1;
}
public:
//Gini
RF() {
forest.clear();
}
int predWithForest(vector<int> x, int &prob) {
map<int, int> mp; mp.clear();
for (int i = 0; i < forest.size(); ++i) {
int pred = predWithTree(forest[i], x);
if (pred != -1) {
mp[pred]++;
}
}
int winClass = -1, winNumber = -1;
for (auto v : mp) {
cout << "first " << v.first << " second " << v.second << endl;
if (v.second > winNumber) {
winClass = v.first;
winNumber = v.second;
prob = winNumber;
}
}
return winClass;
}
void print() {
cout << "forest size " << forest.size() << endl;
for (auto ts : forest) {
for (auto t : ts) {
cout << "[";
cout << "(" << t.splitFeature << "," << t.splitValue << "," << t.winClass << ")" << "|";
for (auto tt:t.samp) {
cout << tt.x[0] <<"," <<tt.y <<" ";
}
cout << "]";
}
cout << endl;
}
}
void Training(Datas &data, int treeNumber, int bagNumber, int depth, int bestSelect) {
Option option;
option.treeNumber = treeNumber;
option.bagNumber = bagNumber;
option.depth = depth;
option.bestSelect = bestSelect;
createForest(data, option);
}
};
#endif /* RF_H_ */