这是本人的第一篇文章。趁着数据挖掘课程设计的时间,把自己实现的决策树代码拿出来分享一下。代码里还有不少漏洞和缺陷,也有很多取巧(hack)的写法,但总之能够运行;网上的同类代码,真正能用的其实也没几个。废话不多说,直接看代码。
特别鸣谢博主skyonefly的代码
附上链接:R语言决策树代码
#####
##### C4.5算法实现决策树
#####
#############################################Part1 基础函数#######################################################
# Compute the Shannon entropy (base 2) of one column of a data set.
#
# dataSet: a data.frame or matrix whose rows are observations.
# labels:  name or index of the column whose value distribution is measured.
#
# Returns a single unnamed numeric: -sum(p * log2(p)) over the distinct
# non-NA values of the column, where p is each value's count divided by the
# total number of rows. An empty (or NULL) data set yields 0.
calShannonEnt <- function(dataSet, labels) {
  classValues <- dataSet[, labels]
  numEntries <- length(classValues)
  if (numEntries == 0) {
    return(0)
  }
  # Tabulate the character form so that unused factor levels cannot add
  # zero-count cells (0 * log2(0) would evaluate to NaN).
  prob <- as.numeric(table(as.character(classValues))) / numEntries
  -sum(prob * log2(prob))
}
# Select the rows of a data set whose value in one column equals `value`.
#
# dataSet: data.frame or matrix used to decide which rows match.
# axis:    name or index of the column of `dataSet` that is compared.
# value:   value the column must equal for a row to be kept.
# tempSet: object the matching rows are taken from; defaults to `dataSet`.
#          Rows are picked by position, so it is expected to have the same
#          row order as `dataSet`.
#
# Returns the matching rows of `tempSet` with row names reset, or NULL when
# no row matches (including when `dataSet` is NULL or has zero rows).
# Rows whose compared value is NA never match.
splitDataSet <- function(dataSet, axis, value, tempSet = dataSet) {
  if (is.null(dataSet) || nrow(dataSet) == 0) {
    return(NULL)
  }
  # which() drops NA comparisons instead of failing on them.
  matched <- which(dataSet[, axis] == value)
  if (length(matched) == 0) {
    return(NULL)
  }
  # One vectorized subset replaces the row-by-row rbind accumulation.
  retDataSet <- tempSet[matched, , drop = FALSE]
  rownames(retDataSet) <- NULL
  retDataSet
}
# Choose the feature column with the highest gain ratio (C4.5 criterion).
#
# dataSet:      data.frame or matrix; columns 1..(ncol - 1) are evaluated as
#               candidate features.
#               NOTE(review): this presumes the class column is the last
#               column -- confirm against callers.
# labels:       name or index of the class column, passed to calShannonEnt.
# bestInfoGain: threshold a feature's gain ratio must exceed to be selected.
#
# Returns the column index of the selected feature, or -1 when no feature
# has a positive split information and a gain ratio above the threshold.
chooseBestFeatureToSplita <- function(dataSet, labels, bestInfoGain) {
  numFeatures <- ncol(dataSet) - 1
  numRows <- nrow(dataSet)
  baseEntropy <- calShannonEnt(dataSet, labels)
  bestFeature <- -1
  for (i in seq_len(max(numFeatures, 0))) {
    featureLabels <- levels(factor(dataSet[, i]))
    newEntropy <- 0.0
    SplitInfo <- 0.0
    for (j in seq_along(featureLabels)) {
      subDataSet <- splitDataSet(dataSet, i, featureLabels[j])
      prob <- length(subDataSet[, 1]) / numRows
      newEntropy <- newEntropy + prob * calShannonEnt(subDataSet, labels)
      SplitInfo <- SplitInfo - prob * log2(prob)
    }
    infoGain <- baseEntropy - newEntropy
    # A feature with a single distinct value has zero split information;
    # skip it so the ratio is never computed as a division by zero.
    if (SplitInfo > 0) {
      GainRadio <- infoGain / SplitInfo
      if (GainRadio > bestInfoGain) {
        # Keep the running best on the same scale as the comparison above
        # (gain ratio, not raw information gain).
        bestInfoGain <- GainRadio
        bestFeature <- i
      }
    }
  }
  bestFeature
}
# Return the most frequent value in a vector of class labels.
#
# classList: vector (or factor) of class labels.
#
# Returns the winning label as a character string. Ties are resolved in
# favour of the level that comes first in levels(as.factor(classList)).
# Returns NA_character_ when there is nothing to count.
majorityCnt <- function(classList) {
  count <- as.numeric(table(classList))
  if (length(count) == 0) {
    return(NA_character_)
  }
  majorityList <- levels(as.factor(classList))
  # which.max() gives the first position of the maximum, which is the same
  # element the which(count == max(count))[1] form selects.
  majorityList[which.max(count)]
}
# Report whether a vector of class labels tabulates to exactly one cell.
#
# classList: vector (or factor) of class labels.
#
# Returns TRUE when table(classList) has exactly one entry, FALSE otherwise.
oneValue <- function(classList) {
  distinctCells <- length(table(classList))
  distinctCells == 1
}
#树的打印
printTree <- function(tree){
df <- data.frame()
col <- 1
point <- c()
count <- 0
for(i in 1:length(tree)){
if(rownames(tree)[[i]] == 'labelFeature'){
df[i,col] = tree[[i]]
col = col + 1
count = count + 1
point[count] = col
name