Random Forest
(1). Review of random forest
The random forest algorithm is a classifier based on two primary methods: bagging and the random subspace method.
First, a review of the two main ensemble algorithms, bagging and boosting:
- bagging builds many approximately unbiased models using bootstrap samples and the same features; averaging or voting over these models gives the final prediction. It is used to reduce variance.
- boosting starts with a weak learner and gradually improves it by refitting the data, giving higher weights to the misclassified samples. The final classifier is built by weighted voting.
Random forest is a substantial modification of bagging applied to CART trees. It adds another source of randomness compared with bagging, which only samples cases with replacement (bootstrap): when splitting each node of each CART tree, random forest also samples a subset of features without replacement and tests only this subset to find the best-performing feature to split the data at this node.
As we all know, averaging B i.i.d. random variables, each with variance $\sigma^2$, gives a variance of $\sigma^2/B$. But if these B variables are only identically distributed and not independent (as in bagging), with positive pairwise correlation $\rho$, the variance of the average is
$$\rho\sigma^2 + \frac{1-\rho}{B}\sigma^2.$$
proof:
$$\mathrm{Var}\Big(\frac{1}{B}\sum_{i=1}^{B}x_i\Big) = \frac{1}{B^2}\Big(\sum_{i=1}^{B}\mathrm{Var}(x_i) + \sum_{i\neq j}\mathrm{Cov}(x_i,x_j)\Big) = \frac{1}{B^2}\big(B\sigma^2 + B(B-1)\rho\sigma^2\big) = \rho\sigma^2 + \frac{1-\rho}{B}\sigma^2$$
($\rho$ is the correlation of each pair of bootstrapped trees), so as B increases, the variance of the average decreases toward $\rho\sigma^2$. As seen, if we can reduce the correlation $\rho$ without increasing $\sigma^2$ too much, we can reduce the variance of the average. Random forest achieves this idea by random selection of input features: before each split, select $m < p$ of the input variables as candidates for splitting; for regression, $m$ is usually set to $\lfloor p/3 \rfloor$, and for classification, $m = \lfloor\sqrt{p}\rfloor$.
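This variance formula is easy to verify empirically. Below is a minimal Python sketch (the constants B, rho, and the number of replications are arbitrary choices of mine, not from the text):

```python
import numpy as np

# Empirical check of Var(average) = rho*sigma^2 + (1 - rho)/B * sigma^2
# for B identically distributed variables with pairwise correlation rho.
rng = np.random.default_rng(0)
B, rho, reps = 20, 0.3, 200_000

z = rng.normal(size=(reps, 1))                 # shared component
e = rng.normal(size=(reps, B))                 # independent components
x = np.sqrt(rho) * z + np.sqrt(1 - rho) * e    # Var(x_i) = 1, Corr(x_i, x_j) = rho

empirical = x.mean(axis=1).var()               # variance of the average
theory = rho + (1 - rho) / B                   # the formula, with sigma^2 = 1
```

With sigma^2 = 1, the theoretical value is 0.3 + 0.7/20 = 0.335, and the simulated variance matches it closely.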
So now we can state that:
- bootstrap sampling and averaging are used to reduce the variance of the ensemble.
- random selection of a feature subset is used to reduce the correlation between each pair of trees.
pseudo-code for random forest:
- For b = 1 to B:
(a) Draw a bootstrap sample $Z^*$ of size N from the training data.
(b) Grow a random-forest tree $T_b$ on the bootstrapped data, by recursively repeating the following steps for each terminal node of the tree until the minimum node size $n_{min}$ is reached:
- Select m variables at random from the p variables.
- Pick the best variable/split-point among the m.
- Split the node into two daughter nodes.
- Output the ensemble of trees $\{T_b\}_{b=1}^{B}$.
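The loop above can be written as a short Python sketch (illustrative only: it borrows scikit-learn's DecisionTreeClassifier as the base tree, and the names fit_random_forest and predict_forest are mine, not a standard API):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def fit_random_forest(X, y, B=50, m=None, n_min=1, seed=0):
    """For b = 1..B: draw a bootstrap sample, grow a tree that tests only
    m random features at each split (max_features), stop at node size n_min."""
    rng = np.random.default_rng(seed)
    n, p = X.shape
    m = m if m is not None else max(1, int(np.sqrt(p)))  # classification default
    trees = []
    for b in range(B):
        idx = rng.integers(0, n, size=n)  # bootstrap sample Z* of size N
        tree = DecisionTreeClassifier(max_features=m, min_samples_leaf=n_min,
                                      random_state=b)
        trees.append(tree.fit(X[idx], y[idx]))
    return trees

def predict_forest(trees, X):
    """Majority vote over the ensemble {T_b}."""
    votes = np.stack([t.predict(X) for t in trees]).astype(int)  # (B, n)
    return np.apply_along_axis(lambda v: np.bincount(v).argmax(), 0, votes)

X, y = load_iris(return_X_y=True)
forest = fit_random_forest(X, y)
train_acc = (predict_forest(forest, X) == y).mean()
```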
(2). Models random forest can be applied to
Notice that random forest performs well only with non-linear base models such as decision trees; it is not suitable for linear models.
The reason: bagging is an additive ensemble technique, and an average of linear models is itself linear. Fitting a linear model is a convex problem, so we can already find the best possible solution; since bagging again produces a linear model, it cannot beat that solution. Here we give an example using the sample mean, which is linear: suppose
$x_1, x_2, \ldots, x_N$ are i.i.d. with mean $\mu$ and variance $\sigma^2$. Let $\bar{x}^*_i$ be the $i$th bootstrap realization of the sample mean ($i \in 1{:}B$). The bagging model for the sample mean is
$$\bar{x}_{bag} = \frac{1}{B}\sum_{i=1}^{B}\bar{x}^*_i.$$
Here we declare that
$$\mathrm{Var}(\bar{x}^*_i) = \frac{2N-1}{N^2}\sigma^2, \qquad \mathrm{Cov}(\bar{x}^*_i,\bar{x}^*_j) = \frac{\sigma^2}{N}\ (i\neq j), \qquad \rho = \frac{N}{2N-1}.$$
proof: conditioning on the training data $x$, $E[\bar{x}^*_i \mid x] = \bar{x}$ and $\mathrm{Var}(\bar{x}^*_i \mid x) = \hat{\sigma}^2/N$, where $\hat{\sigma}^2 = \frac{1}{N}\sum_{j=1}^{N}(x_j-\bar{x})^2$ has $E[\hat{\sigma}^2] = \frac{N-1}{N}\sigma^2$. Then
$$\mathrm{Var}(\bar{x}^*_i) = \mathrm{Var}\big(E[\bar{x}^*_i \mid x]\big) + E\big[\mathrm{Var}(\bar{x}^*_i \mid x)\big] = \frac{\sigma^2}{N} + \frac{(N-1)\sigma^2}{N^2} = \frac{2N-1}{N^2}\sigma^2,$$
and since the $\bar{x}^*_i$ are conditionally independent given the data,
$$\mathrm{Cov}(\bar{x}^*_i,\bar{x}^*_j) = \mathrm{Var}\big(E[\bar{x}^*_i \mid x]\big) = \mathrm{Var}(\bar{x}) = \frac{\sigma^2}{N}.$$
Then using the formula $\rho\sigma^2 + \frac{1-\rho}{B}\sigma^2$ (with $\sigma^2$ replaced by $\mathrm{Var}(\bar{x}^*_i)$), we get the variance of bagging:
$$\mathrm{Var}(\bar{x}_{bag}) = \frac{\sigma^2}{N} + \frac{1}{B}\cdot\frac{(N-1)\sigma^2}{N^2} \;\ge\; \frac{\sigma^2}{N} = \mathrm{Var}(\bar{x}),$$
which says bagging of a linear model cannot reduce variance.
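A Monte Carlo sketch of this claim (Python; N, B, and the replication count are arbitrary illustration values):

```python
import numpy as np

# Check that bagging the sample mean (a linear statistic) does not reduce
# its variance below Var(x_bar) = sigma^2 / N; it is in fact slightly larger.
rng = np.random.default_rng(0)
N, B, reps = 50, 100, 3000

plain, bagged = [], []
for _ in range(reps):
    x = rng.normal(0.0, 1.0, size=N)                 # sigma^2 = 1
    plain.append(x.mean())                           # ordinary sample mean
    boot = rng.choice(x, size=(B, N), replace=True)  # B bootstrap samples
    bagged.append(boot.mean(axis=1).mean())          # bagged sample mean

var_plain = np.var(plain)    # ~ sigma^2/N = 0.02
var_bagged = np.var(bagged)  # ~ sigma^2/N + (N-1)*sigma^2/(N^2 * B)
```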
(3). OOB (out-of-bag) error estimation
An important property of random forest is the use of out-of-bag samples: the samples not included in the bootstrap sample used to build a given tree. The OOB error is computed as follows:
For each observation $z_i = (x_i, y_i)$, construct its random forest predictor by averaging only those trees corresponding to bootstrap samples that do not contain $z_i$.
Formally, suppose B trees $\{T_1, T_2, \ldots, T_B\}$ have been built, and for each individual sample $x_i$ there exists a subset of trees $S_i$ built without it. We then predict the label for $x_i$ using only the trees in $S_i$:
$$\hat{y}_i = \arg\max_k \sum_{T\in S_i} \mathbb{1}\{T(x_i) = k\},$$
and the OOB error is the average over all samples:
$$\mathrm{Err}_{OOB} = \frac{1}{N}\sum_{i=1}^{N}\mathbb{1}\{\hat{y}_i \neq y_i\}.$$
The OOB error is close to the N-fold (leave-one-out) cross-validation error, so we do not need to perform cross-validation alongside tree building; once the OOB error stabilizes, training can be terminated. [OOB error can therefore be used to select the number of trees to build.]
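The OOB bookkeeping can be sketched as follows (again a Python illustration using scikit-learn trees; the vote-matrix layout is just one way to organize it):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
n = X.shape[0]
B, rng = 200, np.random.default_rng(0)
votes = np.zeros((n, len(np.unique(y))))  # OOB vote counts per sample/class

for b in range(B):
    idx = rng.integers(0, n, size=n)         # bootstrap sample for tree b
    oob = np.setdiff1d(np.arange(n), idx)    # samples z_i left out of tree b
    tree = DecisionTreeClassifier(max_features="sqrt", random_state=b)
    tree.fit(X[idx], y[idx])
    votes[oob, tree.predict(X[oob])] += 1    # vote only where z_i is OOB

oob_pred = votes.argmax(axis=1)              # majority class over trees in S_i
oob_error = np.mean(oob_pred != y)           # average over all samples
```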
(4). Variable importance
There are two ways to evaluate variable importance:
- Gini importance: at each split in each tree, the improvement in the split criterion is the importance measure attributed to the splitting variable, and it is accumulated over all trees separately for each variable. Formally, in each decision tree $T_b$, $b \in 1{:}B$, the squared importance measure of the $\ell$th variable is defined as
$$\mathrm{VI}^2_\ell(T_b) = \sum_{t=1}^{J-1}\hat{i}^2_t\,\mathbb{1}\{v(t)=\ell\},$$
where $v(t)$ is the index of the variable used to split node $t$ into two daughter nodes, the sum is over the $J-1$ internal nodes of the tree, and $\hat{i}^2_t$ is the improvement in the split criterion (such as squared-error risk or Gini impurity). Averaging over all B trees gives the total squared importance measure of the $\ell$th variable:
$$\mathrm{VI}^2_\ell = \frac{1}{B}\sum_{b=1}^{B}\mathrm{VI}^2_\ell(T_b) = \frac{1}{B}\sum_{b=1}^{B}\sum_{t=1}^{J-1}\hat{i}^2_t\,\mathbb{1}\{v(t)=\ell\}.$$
- Permutation importance: uses the OOB samples to construct a different importance measure. When $T_b$ is built, its OOB samples are passed down the tree and the prediction accuracy is recorded:
$$C_b = \frac{1}{|OOB(b)|}\sum_{i\in OOB(b)}\mathbb{1}\{\hat{y}_i = y_i\}.$$
Then the values of the $\ell$th variable are randomly permuted in the OOB samples, and the prediction accuracy is computed again:
$$C_{b\ell} = \frac{1}{|OOB(b)|}\sum_{i\in OOB(b)}\mathbb{1}\{\hat{y}_{i,\pi_\ell} = y_i\}.$$
The difference $\mathrm{VI}_\ell(T_b) = C_b - C_{b\ell}$ is taken as the permutation importance of the $\ell$th variable in tree $T_b$. Finally, the importance measure of the $\ell$th variable is the average of $\mathrm{VI}_\ell(T_b)$ over the B trees:
$$\mathrm{VI}_\ell = \frac{1}{B}\sum_{b=1}^{B}\mathrm{VI}_\ell(T_b).$$
summary: Gini importance measures how often variable $\ell$ is used to split a node and how useful those splits are. Permutation importance measures how useful variable $\ell$ is for predicting held-out (OOB) data.
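Permutation importance as defined above can be sketched like this (Python with scikit-learn trees; the constants are arbitrary illustration values):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
n, p = X.shape
B, rng = 100, np.random.default_rng(0)
vi = np.zeros(p)  # permutation importance per variable

for b in range(B):
    idx = rng.integers(0, n, size=n)
    oob = np.setdiff1d(np.arange(n), idx)
    tree = DecisionTreeClassifier(max_features="sqrt", random_state=b)
    tree.fit(X[idx], y[idx])
    c_b = (tree.predict(X[oob]) == y[oob]).mean()    # OOB accuracy C_b
    for l in range(p):
        Xp = X[oob].copy()
        Xp[:, l] = rng.permutation(Xp[:, l])         # permute variable l
        c_bl = (tree.predict(Xp) == y[oob]).mean()   # permuted accuracy C_bl
        vi[l] += (c_b - c_bl) / B                    # average accuracy drop
```

Variables whose permutation hurts OOB accuracy the most receive the largest scores.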
(5). Proximity plot
In growing a random forest, an $N \times N$ proximity matrix is accumulated for the training data. For every tree, the proximity of any pair of OOB samples sharing a terminal node is increased by one. That is to say: for any pair of training samples, the proximity is the number of trees in which both samples are out-of-bag and fall into the same terminal node.
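A sketch of the proximity accumulation (Python, using scikit-learn's tree.apply to get terminal-node ids; this is an illustration, not the randomForest package's implementation):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
n = X.shape[0]
B, rng = 100, np.random.default_rng(0)
prox = np.zeros((n, n))  # N x N proximity matrix

for b in range(B):
    idx = rng.integers(0, n, size=n)
    oob = np.setdiff1d(np.arange(n), idx)    # OOB samples for this tree
    tree = DecisionTreeClassifier(max_features="sqrt", random_state=b)
    tree.fit(X[idx], y[idx])
    leaves = tree.apply(X[oob])              # terminal node of each OOB sample
    same = leaves[:, None] == leaves[None, :]   # OOB pairs sharing a terminal node
    prox[np.ix_(oob, oob)] += same              # increase their proximity by one
```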
(6). R code for random forest
Here is a simple R example of random forest using the randomForest package.
rm(list = ls(all = TRUE))  # remove all objects
# install.packages("randomForest")  # install the randomForest package
library(randomForest)  # load the package
data(iris)  # data to analyze
n <- nrow(iris)  # number of samples
p <- ncol(iris) - 1  # number of predictor variables (column 5 is the class label)
test_inx <- sample(1:n, n/5)  # hold out n/5 samples as test data
iris_train <- iris[-test_inx, ]  # training data
iris_test <- iris[test_inx, ]  # test data
iris_rf <- randomForest(iris_train[, -5], iris_train[, 5], ntree = 1000, mtry = floor(sqrt(p)), replace = TRUE, importance = TRUE, proximity = TRUE)  # fit the model on the training data
print(iris_rf)  # view the fitted model, including the OOB error estimate
iris_rf$importance  # variable importance measures
varImpPlot(iris_rf)  # plot variable importance
iris_predict <- predict(iris_rf, iris_test[, -5], type = "response")  # predict class labels for the test data
table(observed = iris_test[, 5], predicted = iris_predict)  # confusion table of prediction results
error <- sum(iris_predict != iris_test[, 5]) / length(test_inx)  # test prediction error
error