使用 python 手写 决策树

使用 python 自己写一个 决策树

很多复杂的学习方法,明白了其基础之后,一切就变得简单、易懂,并且符合直觉。
我今天打算手写一个决策树,或者说是“分类回归树”。

参考

https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/

决策树是一种强大的预测方法,在工业界的数据分析和决策中非常好用。
它受欢迎的一个主要原因是,它最终的模型就像给你写出了一堆 if else 的判断条件。
这样,无论通过模型获得怎样的结果,都可以通过判断条件进行反推。
因为这个优势的存在,决策树就可以帮助我们这些从业人员,更好地分析数据。

同时决策树还是其它先进的集成学习方法(advanced ensemble methods)的基石。
如,bagging, random forests 和 gradient boosting。

待解决的几个问题

  • 如何评估分裂点。知道如何评估之后,就可以找到它了。
  • How to arrange splits into a decision tree structure.
  • 怎么在一个实际问题中使用决策树

简介

Classification and Regression Tree(CART) 是对决策树的一个比较现代的叫法。
决策树用到的模型是二叉树这种数据结构,每个结点有 0 个,或者 1 个,或者 2 个子结点。

每个结点,代表了一个输入变量和这个变量的分割点,叶结点包含一个输出。

决策树一旦创建,新的数据就可以通过根结点,经过各种判断条件,最终到叶结点,并把叶结点表示的结果作为输出。

分割时的损失函数

  • 回归模型的损失函数: 使用最小方差作为损失函数
  • 分类模型损失函数: 使用 Gini 损失函数,Gini Score 是用来衡量一个子结点中包含类型的混乱程度的,越混乱 Gini Score 越大,分类效果自然是就越差。

用来做测试的数据 Banknote Dataset

Banknote Dataset 给出了一些用来描述钞票图片的数据,并用这些数据对钞票的真伪进行判断。

这个数据集饮食 1372 个样本,每行有 5 个数值特征。
这个任务是一个二分类任务。

特征描述如下:

1. 图像的小波变换方差(连续数值)
2. 图像的小波变换偏度(skewness)(连续数值)
3. 图像的小波变换峭度(kurtosis)(连续数值)
4. 图像的熵(连续数值)
5. 样本的类别(整数)

数据示例如下:

3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,0
0.32924,-4.4552,4.5718,-0.9888,0
4.3684,9.6718,-3.9606,-3.1625,0

未理解问题
Zero Rule Algorithm

Using the Zero Rule Algorithm to predict the most common class value, the baseline accuracy on the problem is about 50%.
这个意思好像是说,就按大多数的类型作为全部的类别,进而得到分类正确率。并用它作为 baseline.

下载

banknote authentication Data Set

banknote_authentication = './data_banknote_authentication.csv'

目录

下面分五部分分别实现

  1. Gini Index
  2. Create Split
  3. Build a Tree
  4. Make a Prediction
  5. Banknote Case Study

Gini Index

def gini_score(groups, classes):
    '''
    row = [col1, col2, col3, col4, class]
    group: [row, row, ..., row]
    groups: [group, group, ..., group]
        
    classes: [0, 1]
    '''
    # weight = sum(group) / sum(sum(group))
    # Gini index = sum(sum(one_class) / sum(group))
    instances_num = sum([len(group) for group in groups])
    gini_score = 0.0
    for group in groups:
        group_num = len(group)
        if group_num == 0:
            continue
        gini_index = 1.0
        for c in classes:
            c_num = 0
            for i in group:
                if c == i[-1]:
                    c_num += 1
            c_p = c_num / group_num # group_num == 0?
            gini_index = gini_index - c_p * c_p
        group_gini_score = gini_index * group_num / instances_num
        gini_score += group_gini_score
    return gini_score
row1 = [0,0,0,0,1]
row2 = [0,0,0,0,0]
group1 = [row1, row2]
row3 = [0,0,0,0,0]
row4 = [0,0,0,0,1]
row5 = [0,0,0,0,1]
group2 = [row3, row4, row5]

test_data_for_gini_index = [group1, group2]

classes = [0, 1]
gini_score(test_data_for_gini_index, classes)
0.4666666666666667

博客里面的函数对比测试

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# test Gini values
groups1 = [[[1, 1], [1, 0]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值