使用 python 自己写一个 决策树
很多复杂的学习方法,明白了其基础之后,一切就变得简单、易懂,并且符合直觉。
我今天打算手写一个决策树,或者说是“分类回归树”。
参考
https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
决策树是一种强大的预测方法,在工业界的数据分析和决策中非常好用。
它受欢迎的一个主要原因是,它最终的模型就像给你写出了一堆 if else
的判断条件。
这样,无论通过模型获得怎样的结果,都可以通过判断条件进行反推。
因为这个优势的存在,决策树就可以帮助我们这些从业人员,更好地分析数据。
同时决策树还是其它先进的集成学习方法(advanced ensemble methods)的基石。
如,bagging, random forests 和 gradient boosting。
待解决的几个问题
- 如何评估分裂点。知道如何评估之后,就可以找到它了。
- How to arrange splits into a decision tree structure.
- 怎么在一个实际问题中使用决策树
简介
Classification and Regression Tree(CART) 是对决策树的一个比较现代的叫法。
决策树用到的模型是二叉树这种数据结构,每个结点有 0 个,或者 1 个,或者 2 个子结点。
每个结点,代表了一个输入变量和这个变量的分割点,叶结点包含一个输出。
决策树一旦创建,新的数据就可以通过根结点,经过各种判断条件,最终到叶结点,并把叶结点表示的结果作为输出。
分割时的损失函数
- 回归模型的损失函数: 使用最小方差作为损失函数
- 分类模型损失函数: 使用 Gini 损失函数,Gini Score 是用来衡量一个子结点中包含类型的混乱程度的,越混乱 Gini Score 越大,分类效果自然是就越差。
用来做测试的数据 Banknote Dataset
Banknote Dataset 给出了一些用来描述钞票图片的数据,并用这些数据对钞票的真伪进行判断。
这个数据集饮食 1372 个样本,每行有 5 个数值特征。
这个任务是一个二分类任务。
特征描述如下:
1. 图像的小波变换方差(连续数值)
2. 图像的小波变换偏度(skewness)(连续数值)
3. 图像的小波变换峭度(kurtosis)(连续数值)
4. 图像的熵(连续数值)
5. 样本的类别(整数)
数据示例如下:
3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,0
0.32924,-4.4552,4.5718,-0.9888,0
4.3684,9.6718,-3.9606,-3.1625,0
未理解问题
Zero Rule Algorithm
Using the Zero Rule Algorithm to predict the most common class value, the baseline accuracy on the problem is about 50%.
这个意思好像是说,就按大多数的类型作为全部的类别,进而得到分类正确率。并用它作为 baseline.
下载
banknote authentication Data Set
banknote_authentication = './data_banknote_authentication.csv'
目录
下面分五部分分别实现
- Gini Index
- Create Split
- Build a Tree
- Make a Prediction
- Banknote Case Study
Gini Index
def gini_score(groups, classes):
'''
row = [col1, col2, col3, col4, class]
group: [row, row, ..., row]
groups: [group, group, ..., group]
classes: [0, 1]
'''
# weight = sum(group) / sum(sum(group))
# Gini index = sum(sum(one_class) / sum(group))
instances_num = sum([len(group) for group in groups])
gini_score = 0.0
for group in groups:
group_num = len(group)
if group_num == 0:
continue
gini_index = 1.0
for c in classes:
c_num = 0
for i in group:
if c == i[-1]:
c_num += 1
c_p = c_num / group_num # group_num == 0?
gini_index = gini_index - c_p * c_p
group_gini_score = gini_index * group_num / instances_num
gini_score += group_gini_score
return gini_score
row1 = [0,0,0,0,1]
row2 = [0,0,0,0,0]
group1 = [row1, row2]
row3 = [0,0,0,0,0]
row4 = [0,0,0,0,1]
row5 = [0,0,0,0,1]
group2 = [row3, row4, row5]
test_data_for_gini_index = [group1, group2]
classes = [0, 1]
gini_score(test_data_for_gini_index, classes)
0.4666666666666667
博客里面的函数对比测试
# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
# count all samples at split point
n_instances = float(sum([len(group) for group in groups]))
# sum weighted Gini index for each group
gini = 0.0
for group in groups:
size = float(len(group))
# avoid divide by zero
if size == 0:
continue
score = 0.0
# score the group based on the score for each class
for class_val in classes:
p = [row[-1] for row in group].count(class_val) / size
score += p * p
# weight the group score by its relative size
gini += (1.0 - score) * (size / n_instances)
return gini
# test Gini values
groups1 = [[[1, 1], [1, 0]