随机森林的简单实现

近日听了七月天空周博的课,现在对随机森林做一个简单的实现。

随机森林(Random Forest)是一种利用多棵分类树对数据进行判别与分类的方法,它在对数据进行分类的同时,还可以给出各个变量(基因)的重要性评分,评估各个变量在分类中所起的作用。

随机森林是一个最近比较火的算法,它有很多的优点:

a. 在数据集上表现良好,两个随机性的引入,使得随机森林不容易陷入过拟合

b. 在当前的很多数据集上,相对其他算法有着很大的优势,两个随机性的引入,使得随机森林具有很好的抗噪声能力

c. 它能够处理很高维度(feature很多)的数据,并且不用做特征选择,对数据集的适应能力强:既能处理离散型数据,也能处理连续型数据,数据集无需规范化

d. 可生成一个 Proximities 矩阵 $P=(p_{ij})$,用于度量样本之间的相似性:$p_{ij}=a_{ij}/N$,其中 $a_{ij}$ 表示样本 $i$ 和样本 $j$ 出现在随机森林中同一个叶子结点的次数,$N$ 为随机森林中树的棵数

e. 在创建随机森林的时候,对泛化误差(generalization error)使用的是无偏估计

f. 训练速度快,可以得到变量重要性排序(两种:基于OOB误分率的增加量和基于分裂时的GINI下降量)

g. 在训练过程中,能够检测到feature间的互相影响

h. 容易做成并行化方法

i. 实现比较简单


/*
 * RF.h
 *
 *  Created on: Nov 8, 2015
 *      Author: shenjiyi
 */

#ifndef RF_H_
#define RF_H_

#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <vector>

using namespace std;

// Resolution of the random fractions produced by RF's rand()-based helpers:
// rand() is reduced modulo this value and then divided by it.
#define MAX_RANDOM 10000
// NOTE(review): MAXN is not referenced anywhere in the visible part of this
// file; presumably a leftover size limit — confirm before removing.
#define MAXN 10000

// One labelled sample used for training and prediction.
struct Data {
	std::vector<int> x;  // feature values; x[k] is the value of feature k
	int y;               // class label of the sample
};

// A collection of labelled samples (a data set, a bag, or one node's share).
using Datas = std::vector<Data>;

// Result of one candidate split of a node's samples on a single feature.
struct SubTreeType {
	double v;     // threshold the samples were compared against
	Datas right;  // samples whose feature value is greater than v
	Datas left;   // samples whose feature value is not greater than v
};

// One node of a decision tree stored in a flat array, where the children of
// node i live at indices 2*i+1 and 2*i+2.
struct TreeNodeType {
	// Threshold of the split chosen for this node; -1 until a split is chosen.
	double splitValue;
	// Index of the feature this node splits on; -1 until a split is chosen.
	// NOTE(review): holds an integer index but is declared double, so every
	// use as a subscript goes through an implicit conversion — consider int.
	double splitFeature;
	// Majority class label of the node's samples when the node acts as a
	// leaf; -1 when no class has been assigned.
	int winClass;
	// Training samples that reached this node while the tree was built.
	Datas samp;
};

typedef vector<TreeNodeType> Tree;

typedef vector<Tree> Forest;

// Training parameters handed from RF::Training to the forest builder.
struct Option {
	int treeNumber;  // number of trees to build
	int bagNumber;   // number of samples drawn (with replacement) for each tree
	int depth;       // depth of every tree; a tree holds 2^(depth+1)-1 node slots
	int bestSelect;  // number of random split candidates tried per node
};

class RF {
private:

	Forest forest;

	double calGini(Datas &data, int &winClass) {
		map<int, int> mp; mp.clear();

		for (auto d : data) {
			mp[d.y]++;
		}
		double sum = 0;
		int winNumber = 0;
		for (auto k : mp) {
//			cout << k.first << "\t" << k.second << endl;
			int key = k.first;
			sum = sum + mp[key] * mp[key] /(data.size() + 0.0)/ (data.size() + 0.0);
			if (mp[key] > winNumber) {
				winClass = key;
				winNumber = mp[key];
			}
		}
		return 1 - sum;
	}

	double splitGini(Datas &left, Datas &right) {
		int totalNumber = left.size() + right.size();
		int idx = -1;
		double sum = (left.size() + 0.0 / totalNumber) * calGini(left, idx)
		+ (right.size() + 0.0 / totalNumber) * calGini(right, idx);
		return sum;
	}

	double randomf(double a,double b){
		return (rand()%(int)((b-a)*MAX_RANDOM))
			/(double)MAX_RANDOM+a;
	}

	double randomi(double a, double b) {
		return floor(randomf(a, b));
	}

	double randoms() {
		return (rand()%(int)(MAX_RANDOM))
			/(double)MAX_RANDOM;
	}

	void randomSplit(Datas &data, int feature, SubTreeType& subTree) {
		int a = randomi(0, data.size());
		int b = randomi(0, data.size());
		while (a == b) {b = randomi(0, data.size());}
		double s = randoms();
		double splitValue = s * data[a].x[feature]
			+ (1 - s) * data[b].x[feature];
//		cout << "a = "<<a<<" b = "<<b <<" s= " << s << " splitvalue= " << splitValue << endl;
		subTree.v = splitValue;
		for (int i = 0; i < data.size(); ++i) {
			if (data[i].x[feature] > splitValue) {
				subTree.right.push_back(data[i]);
			} else {
				subTree.left.push_back(data[i]);
			}
		}
	}

	void createSingleTree(Datas &data, int depth, int bestSelect, Tree& singleTree) {

		int featureNumber = data[0].x.size();
		int allNumber = pow(2, depth + 1) - 1;
		int nodeNumber = pow(2, depth) - 1;
		singleTree.clear();
		for (int i = 0; i < allNumber; ++i) {
			TreeNodeType tmp;
			singleTree.push_back(tmp);
			singleTree[i].splitValue = -1;
			singleTree[i].winClass = -1;
			singleTree[i].splitFeature = -1;
			singleTree[i].samp.clear();
		}

		for (int i = 0; i < data.size(); ++i) {
			singleTree[0].samp.push_back(data[i]);
		}

		for (int i = 0; i < nodeNumber; i++) {
			Datas &samples = singleTree[i].samp;
			if (samples.size() == 0 || samples.size() == 1) {
				continue;
			}

			int feature = randomi(0, featureNumber);
			int idx;

			double bestGini = calGini(samples, idx);

//			cout << "bestSelect" << bestSelect << "bestGini " << bestGini << endl;
			SubTreeType *bestTree = NULL;
			SubTreeType subTree;
			for (int j = 0; j < bestSelect && bestGini > 0; ++j) {
				subTree.left.clear(); subTree.right.clear();
				subTree.v = -1;
				randomSplit(samples, feature, subTree);
				double newGini = splitGini(subTree.left, subTree.right);
				if (newGini < bestGini) {
					bestGini = newGini;
					bestTree = &subTree;
				}
			}


			if (bestTree != NULL) {
				singleTree[i].splitValue = bestTree->v;
				singleTree[i].splitFeature = feature;
				singleTree[i * 2 + 1].samp = move(bestTree->left);
				singleTree[i * 2 + 2].samp = move(bestTree->right);
			}
		}

//		cout << "sss" << endl;
		for (int i = 0; i < allNumber; ++i) {
			if (singleTree[i].splitValue == -1 && singleTree[i].samp.size() > 0) {
				int idx = -1;
				calGini(singleTree[i].samp, idx);
				singleTree[i].winClass = idx;
			}
		}
	}

	Datas bagging(Datas &data, int bagNumber) {
		Datas bag; bag.clear();
		for (int i = 0; i < bagNumber; ++i) {
			int n = randomi(0, data.size());
			bag.push_back(data[n]);
		}
		return bag;
	}

	void createForest(Datas &data, Option option) {
		int treeNumber = option.treeNumber;
		int bagNumber = option.bagNumber;
		int depth = option.depth;
		int bestSelect = option.bestSelect;

		forest.clear();
		Tree tmp;
		for (int i = 0; i < treeNumber; ++i) {
			Datas subData = bagging(data, bagNumber);
			createSingleTree(subData, depth, bestSelect, tmp);
			forest.push_back(tmp);
		}
	}

	int predWithTree(Tree &tree, vector<int> &x) {
	//	cout <<"tree size " << tree.size() << endl;
		for (int i = 0;;) {
			if (i >= tree.size()) {
				return -1;
			}
			if (tree[i].winClass != -1) {
			//	cout << "tree winclass " << tree[i].winClass << endl;
				return tree[i].winClass;
			}
			if (x[tree[i].splitFeature] < tree[i].splitValue) {
				i = 2 * i + 1;
			} else {
				i = 2 * i + 2;
			}
		}
		return -1;
	}
public:

//Gini
	RF() {
		forest.clear();
	}
	int predWithForest(vector<int> x, int &prob) {
		map<int, int> mp; mp.clear();

		for (int i = 0; i < forest.size(); ++i) {
			int pred = predWithTree(forest[i], x);
			if (pred != -1) {
				mp[pred]++;
			}
		}
		int winClass = -1, winNumber = -1;
		for (auto v : mp) {
			cout << "first " << v.first << " second " << v.second << endl;
			if (v.second > winNumber) {
				winClass = v.first;
				winNumber = v.second;
				prob = winNumber;
			}
		}
		return winClass;
	}

	void print() {
		cout << "forest size " << forest.size() << endl;

		for (auto ts : forest) {

			for (auto t : ts) {
				cout << "[";
				cout << "(" << t.splitFeature << "," << t.splitValue << "," << t.winClass << ")" << "|";
				for (auto tt:t.samp) {
					cout << tt.x[0] <<"," <<tt.y <<" ";
				}
				cout << "]";
			}
			cout << endl;
		}
	}

	void Training(Datas &data, int treeNumber, int bagNumber, int depth, int bestSelect) {
		Option option;
		option.treeNumber = treeNumber;
		option.bagNumber = 	bagNumber;
		option.depth = depth;
		option.bestSelect = bestSelect;

		createForest(data, option);
	}
};


#endif /* RF_H_ */


展开阅读全文

没有更多推荐了,返回首页