详细的决策树(C4.5, ID3, CART)介绍和公式在前面的博文:决策树详解
个人觉得李航博士的《统计学习方法》对机器学习一些算法的解释也很好.
详细的代码可参考知乎:知乎
本文主要是对CART决策树的实现细节做代码展示,代码有比较详细的注解.
数据集有四个特征,最后一个为label——(SepalLength, SepalWidth, PetalLength, PetalWidth, Name),总共150条样本,每个类别50条,Name为(setosa, versicolor, virginica)三类.数据集大概如下:
先附上一张最后决策树的图片(代码没有实现这个功能).
具体代码如下:
# -*- coding: utf-8 -*-
import numpy as np
class Tree:
    """A node of a CART decision tree.

    An internal node stores the split (col, value) and two subtrees;
    a leaf stores the class counts (results) and optionally its rows.
    """

    def __init__(self, value=None, trueBranch=None, falseBranch=None, results=None, col=-1, data=None):
        self.value = value              # split threshold (numeric) or category value
        self.trueBranch = trueBranch    # subtree for rows satisfying the split
        self.falseBranch = falseBranch  # subtree for the remaining rows
        self.results = results          # class-count dict at a leaf; None for internal nodes
        self.col = col                  # feature column index used for the split (-1 = unset)
        self.data = data                # rows kept at a leaf (optional)

    def __str__(self):
        # Bug fix: the original __str__ printed to stdout and returned "";
        # a __str__ must return the textual representation instead.
        return "{} {}\n{}".format(self.col, self.value, self.results)
def calculateDiffCount(datas):
    """Tally how many rows fall into each class label.

    The label is the last element of every row, e.g. the input
    ['A', 'B', 'C', 'A', 'A', 'D'] yields {'A': 3, 'B': 1, 'C': 1, 'D': 1}.
    This is the C_k count used by the Gini computation.
    """
    counts = {}
    for row in datas:
        label = row[-1]  # last column holds the class label
        counts[label] = counts.get(label, 0) + 1
    return counts
def gini(rows):
    """Gini impurity of a set of rows: 1 - sum_k p_k^2.

    p_k is the fraction of rows whose label (last element) is class k.
    The class tally is done inline so the function is self-contained.
    """
    total = len(rows)
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return 1 - sum((n / total) ** 2 for n in counts.values())
def splitDatas(rows, value, column):
    """Partition rows on one feature into (matching, non-matching).

    For a numeric value the test is row[column] >= value;
    for any other value type the test is row[column] == value.
    Returns the two lists (list1, list2) in that order.
    """
    if isinstance(value, (int, float)):
        matches = lambda row: row[column] >= value
    else:
        matches = lambda row: row[column] == value

    hit, miss = [], []
    for row in rows:
        (hit if matches(row) else miss).append(row)
    return hit, miss
def buildDecisionTree(rows, evaluationFunction=gini):
    """Recursively grow a CART tree from rows (last column = label).

    Every (feature, value) pair is tried as a split; the one with the
    largest impurity gain wins. Recursion stops when no split yields a
    positive gain, producing a leaf that holds the class counts.
    """
    currentGain = evaluationFunction(rows)
    n_cols = len(rows[0])
    n_rows = len(rows)

    best_gain, best_value, best_set = 0.0, None, None
    # Exhaustive search over all candidate splits (skip the label column).
    for col in range(n_cols - 1):
        for value in {row[col] for row in rows}:
            part1, part2 = splitDatas(rows, value, col)
            p = len(part1) / n_rows
            gain = currentGain - p * evaluationFunction(part1) - (1 - p) * evaluationFunction(part2)
            if gain > best_gain:
                best_gain, best_value, best_set = gain, (col, value), (part1, part2)

    if best_gain > 0:
        # Internal node: grow both subtrees from the winning partition.
        return Tree(col=best_value[0], value=best_value[1],
                    trueBranch=buildDecisionTree(best_set[0], evaluationFunction),
                    falseBranch=buildDecisionTree(best_set[1], evaluationFunction))
    # Leaf: no split improves impurity any further.
    return Tree(results=calculateDiffCount(rows), data=rows)
def classify(data, tree):
    """Route one sample down the tree and return the leaf's class counts.

    Numeric node values use a >= threshold test, other values use
    equality — mirroring how splitDatas partitioned the training rows.
    (Iterative walk; equivalent to the recursive formulation.)
    """
    node = tree
    while node.results is None:
        v = data[node.col]
        if isinstance(v, (int, float)):
            follow_true = v >= node.value
        else:
            follow_true = v == node.value
        node = node.trueBranch if follow_true else node.falseBranch
    return node.results
def loadCSV(filename="fishiris.csv"):
    """Load a comma-separated dataset, dropping the header row.

    Generalized: the file name is now a parameter; the default keeps the
    original hard-coded behavior, so existing callers are unaffected.

    Each cell is converted to int (no '.') or float (with '.') when
    possible, otherwise kept as a stripped string.

    Returns a list of row lists.
    """
    def convertTypes(s):
        # Best-effort numeric conversion; non-numeric cells stay strings.
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s

    data = np.loadtxt(filename, dtype='str', delimiter=',')
    data = data[1:, :]  # first row is the header — discard it
    return [[convertTypes(item) for item in row] for row in data]
if __name__ == '__main__':
    # Train a CART tree on the iris CSV, then classify one unseen sample
    # and show the class counts of the leaf it lands in.
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    sample = [4.9, 3, 1.4, 0.2]
    print(classify(sample, decisionTree))
预测结果如下: