1. 前言
本文主要讲解CART决策树代码实现的细节,对于想了解决策树原理的同学建议可以去观看台大林轩田教授的视频,他对于决策树以及接下来要利用决策树生成的随机森林算法都讲解的非常好,下面附上链接。https://www.bilibili.com/video/av6991226/#page=34
建议阅读顺序:先阅读源代码,再来看源码关键方法的讲解,源码地址RRdmlearning/Decision-Tree
2. 源码讲解
这部分我争取用简单易懂的语言,阐述源码中关键函数的作用与实现细节。
2.1 calculateDiffCount()
def calculateDiffCount(datas):
#将输入的数据汇总(input dataSet)
#return results Set{type1:type1Count,type2:type2Count ... typeN:typeNCount}
results = {}
for data in datas:
#data[-1] means dataType
if data[-1] not in results:
results[data[-1]] = 1
else:
results[data[-1]] += 1
return results
该函数是计算gini值的辅助函数,假设输入的dataSet为为['A', 'B', 'C', 'A', 'A', 'D'],则输出为['A':3,' B':1, 'C':1, 'D':1],这样分类统计dataSet中每个类别的数量
2.2 gini()
def gini(rows):
#计算gini值(Calculate GINI)
length = len(rows)
results = calculateDiffCount(rows)
imp = 0.0
for i in results:
imp += results[i]/length * results[i]/length
return 1 - imp
这边我们用评判树纯度的方式为gini值,大家也可以使用信息熵等其他方式,只需写一个类似的函数即可
下面我们来看一下gini值的计算方式,不懂公式由来的同学可以看我上面推荐林轩田教授的视频:
gini()中调用了calculateDiffCoun()可以帮我们快速计算出gini值
2.3 splitDatas()
def splitDatas(rows,value,column):
#根据条件分离数据集(splitDatas by value,column)
#return 2 part(list1,list2)
list1 = []
list2 = []
if(isinstance(value,int) or isinstance(value,float)): #for int and float type
for row in rows:
if (row[column] >= value):list1.append(row)
else:list2.append(row)
else: #for String type
for row in rows:
if row[column] == value:list1.append(row)
else:list2.append(row)
return (list1,list2)
这个函数的作用是利用给定的数据(rows),要利用哪个特征切分(column),切分的标准(value)来将数据切分成两份,在下面生成树的过程中会一直循环调用这个函数与gini()来切分成最好的树。
if(isinstance(value,int) or isinstance(value,float))
这个if是为了判断是数值类型还是字符串类型,若是数值类型则比对大小作为切分标准,字符串则比对是否相等。
2.4 buildDecisionTree()
def buildDecisionTree(rows,evaluationFunction=gini):
#递归建立决策树,当gain = 0 时停止递归
#bulid decision tree by recursive function
#stop recursive function when gain = 0
#return tree
currentGain = evaluationFunction(rows)
column_length = len(rows[0])
rows_length = len(rows)
best_gain = 0.0
best_value = None
best_set = None
#choose the best gain
for col in range(column_length-1):
col_value_set = set([x[col] for x in rows])
for value in col_value_set:
list1,list2 = splitDatas(rows,value,col)
p = len(list1)/rows_length
gain = currentGain - p * evaluationFunction(list1) - (1-p) * evaluationFunction(list2)
if gain > best_gain:
best_gain = gain
best_value = (col,value)
best_set = (list1,list2)
dcY = {'impurity' : '%.3f' % currentGain, 'samples' : '%d' % rows_length}
#stop or not stop
if best_gain > 0:
trueBranch = buildDecisionTree(best_set[0],evaluationFunction)
falseBranch = buildDecisionTree(best_set[1],evaluationFunction)
return Tree(col=best_value[0],value=best_value[1],trueBranch=trueBranch,falseBranch=falseBranch,summary=dcY)
else:
return Tree(results=calculateDiffCount(rows),summary=dcY,data=rows)
这是决策树中最关键的一个函数,我来简单讲解一下其中需要注意的实现细节。
for col in range(column_length-1):
这个循环是选取X中的一个特征,假设X是花的数据集,花的特征是:花瓣高度,花瓣宽度,花颜色。
则这个循环就是循环选取花瓣高度,花瓣宽度,花颜色。
col_value_set = set([x[col] for x in rows])
例如花瓣高度(x[col])为['15', '11', '12', '12', '11'],则返回的是['15', '11', '12']
for value in col_value_set
这个循环就是遍历上面得到的['15','11','12']
通过上述三个步骤遍历整个数据集的所有特征值,进行切分,以得到最优的gain值。
最后当得到best_gain时就继续递归调用buildDecisionTree()以生成整个决策树,当best_gain = 0 时说明已经不能再切分了,这时候停止就得到了决策树。
2.5 prune()
def prune(tree,miniGain,evaluationFunction=gini):
#剪枝, when gain < mini Gain,合并(merge the trueBranch and the falseBranch)
if tree.trueBranch.results == None:prune(tree.trueBranch,miniGain,evaluationFunction)
if tree.falseBranch.results == None:prune(tree.falseBranch,miniGain,evaluationFunction)
if tree.trueBranch.results != None and tree.falseBranch.results != None:
len1 = len(tree.trueBranch.data)
len2 = len(tree.falseBranch.data)
len3 = len(tree.trueBranch.data + tree.falseBranch.data)
p = float(len1)/(len1 + len2)
gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)
if(gain < miniGain):
tree.data = tree.trueBranch.data + tree.falseBranch.data
tree.results = calculateDiffCount(tree.data)
tree.trueBranch = None
tree.falseBranch = None
建树后剪枝,是一个正则化的过程,当节点的gain小于给定的 mini Gain时则合并这两个节点
2.6 loadCSV(),plot(),dotgraph()
这三个函数负责读取数据,与画出树的形状,与算法本身没有很大的关系。
这是画出后树的形状:
3. 源码地址:
下面我会附上源码的地址,大家可以直接运行就可以看到效果了
此文章为记录自己一路的学习路程,也希望能给广大初学者们一点点帮助,如有错误,疑惑欢迎一起交流。