决策树算法都要解决两个问题:
1.选择哪个属性来分裂(越乱,熵越大,越不纯)
2.什么时候树停止生长
C4.5流行的原因:
1 用信息增益率来选择属性分裂
2 构造树的过程中进行剪枝
3 能处理连续型数据和不完整数据
特征A对训练数据数据D的信息增益g(D,A)=集合D的经验熵H(D)-特征A给定情况下的经验条件熵H(D|A)
特征A对训练数据D的信息增益比r(D,A)=g(D,A)/H(D)
例子
library(rpart)
##rpart.control对树进行一些设置
##xval是10折交叉验证
##minsplit是最小分支节点数,这里大于等于20,那么该节点会继续分割下去,否则停止
##minbucket叶子节点最小样本数
##maxdepth树的深度
##cp指某个点的复杂度,对每一步拆分模型的拟合优度必须提高,用来节省剪枝浪费的不必要的时间
ct=rpart.control(xval=10,minsplit=20,cp=0.1)
##na.action缺失数据的处理办法
##method树的末端数据类型选择相应的变量分割方法:连续性method="anova",离散型="class"
##parms用来设置三个参数:先验概率,损失矩阵,分类纯度的度量方法(gini和information)
fit=rpart(y~.,data=data,rpart.control=ct,parms=list(prior=c(0.65,0.35),split="information"))
##rpart包提供了复杂度损失修剪的修剪方法,printcp会告诉分类到每一程,cp是多少
printcp(fit)
#确定cp值
cp=fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"])
fit2=prune(fit,cp=cp)
比较c50 rpart party的性能
从图中可以观察到C5.0的表现最好,而party次之,rpart的效果最差。在本例实验中最大的差距虽然不过0.02,但如果放在kaggle的数据挖掘比赛中,就相当于是一百位名次的差距了。
C5.0算法相对于C4.5有如下几点改进:
速度显著加快
内存使用减少
生成树模型更为简洁
支持boosting方法
支持加权和成本矩阵
支持变量筛选
library(C50)
library(rpart)
library(party)
library(reshape2)
library(ggplot2)
data(churn)
rate.c <- rate.r <-rate.p<- rep(0,100)
for (j in 1:100) {
num <- sample(1:10,nrow(churnTrain),replace=T)
res.c <- res.r <-res.p<- array(0,dim=c(2,2,10))
for ( i in 1:10) {
train <- churnTrain[num!=i, ]
test <- churnTrain[num==i, ]
model.c <- C5.0(churn~.,data=train)
pre <- predict(model.c,test[,-20])
res.c[,,i] <- as.matrix(table(pre,test[ ,20]))
model.p <-ctree(churn~.,data=train)
pre <- predict(model.p,test[,-20])
res.p[,,i] <- as.matrix(table(pre,test[ ,20]))
model.r <- rpart(churn~.,data=train)
pre <- predict(model.r,test[,-20],type='class')
res.r[,,i] <- as.matrix(table(pre,test[ ,20]))
}
table.c <- apply(res.c,MARGIN=c(1,2),sum)
rate.c[j] <- sum(diag(table.c))/sum(table.c)
table.p <- apply(res.p,MARGIN=c(1,2),sum)
rate.p[j] <- sum(diag(table.p))/sum(table.p)
table.r <- apply(res.r,MARGIN=c(1,2),sum)
rate.r[j] <- sum(diag(table.r))/sum(table.r)
}
data <- data.frame(c50=rate.c,rpart=rate.r,party=rate.p)
data.melt <- melt(data)
p <- ggplot(data.melt,aes(variable,value,color=variable))
p + geom_point(position='jitter')+
geom_violin(alpha=0.4)