决策树--从原理到实现

本文详细介绍了决策树的基本概念、信息论原理、不同算法(ID3、C4.5、CART)的区别,以及如何通过代码实现CART算法。文章结合实例,解释了决策树在分类和回归任务中的应用,并提供了实际代码供读者学习。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

================================================================================

算算有相当一段时间没写blog了,主要是这学期作业比较多,而且我也没怎么学新的东西

接下来打算实现一个小的toy lib:DML,同时也回顾一下以前学到的东西

当然我只能保证代码的正确性,不能保证其效率啊~~~~~~

之后我会陆续添加进去很多代码,可以供大家学习的时候看,实际使用还是用其它的吧大笑

================================================================================

一.引入

决策树基本上是每一本机器学习入门书籍必讲的东西,其决策过程和平时我们的思维很相似,所以非常好理解,同时有一堆信息论的东西在里面,也算是一个入门应用,决策树也有回归和分类,但一般来说我们主要讲的是分类,方便理解嘛。

虽然说这是一个很简单的算法,但其实现其实还是有些烦人,因为其feature既有离散的,也有连续的,实现的时候要稍加注意

          (不同特征的决策,图片来自【1】)

O-信息论的一些point:

             然后加入一个叫信息增益的东西:
             □.信息增益:(information gain)
                                 g(D,A) = H(D)-H(D|A)
                                 表示了特征A使得数据集D的分类不确定性减少的程度
             □.信息增益比:(information gain ratio)
                                  g‘(D,A)=g(D,A) / H(D)
             □.基尼指数:
             
                        
                         

二.各种算法

1.ID3

ID3算法就是对各个feature信息计算信息增益,然后选择信息增益最大的feature作为决策点将数据分成两部分

                然后再对这两部分分别生成决策树。

                 图自【1】

       

2.C4.5

                C4.5与ID3相比其实就是用信息增益比代替信息增益,应为信息增益有一个缺点:

                       信息增益选择属性时偏向选择取值多的属性

                算法的整体过程其实与ID3差异不大:图自【2】

                 

3.CART

CART(classification and regression tree)的算法整体过程和上面的差异不大,然是CART的决策是二叉树的

每一个决策只能是“是”和“否”,换句话说,即使一个feature有多个可能取值,也只选择其中一个而把数据分类

两部分而不是多个,这里我们主要讲一下分类树,它用到的是基尼指数:

图自【2】

三.代码及实现

                  好吧,其实我就想贴贴代码而已……本代码在https://github.com/justdark/dml/tree/master/dml/DT

                  纯属toy~~~~~实现的CART算法:

                 

from __future__ import division
import numpy as np
import scipy as sp
import pylab as py
def pGini(y):
		ty=y.reshape(-1,).tolist()
		label = set(ty)
		sum=0
		num_case=y.shape[0]
		#print y
		for i in label:
			sum+=(np.count_nonzero(y==i)/num_case)**2
		return 1-sum
	
class DTC:
	def __init__(self,X,y,property=None):
		'''
			this is the class of Decision Tree
			X is a M*N array where M stands for the training case number
								   N is the number of features
			y is a M*1 vector
			property is a binary vector of size N
				property[i]==0 means the the i-th feature is discrete feature,otherwise it's continuous
				in default,all feature is discrete
				
		'''
		'''
			I meet some problem here,because the ndarry can only have one type
			so If your X have some string parameter,all thing will translate to string
			in this situation,you can't have continuous parameter
			so remember:
			if you have continous parameter,DON'T PUT any STRING IN X  !!!!!!!!
		'''
		self.X=np.array(X)
		self.y=np.array(y)
		self.feature_dict={}
		self.labels,self.y=np.unique(y,return_inverse=True)
		self.DT=list()
		if (property==None):
			self.property=np.zeros((self.X.shape[1],1))
		else:
			self.property=property
			
		for i in range(self.X.shape[1]):
			self.feature_dict.setdefault(i)
			self.feature_dict[i]=np.unique(X[:,i])

		if (X.shape[0] != y.shape[0] ):
			print "the shape of X and y is not right"
			
		for i in range(self.X.shape[1]):
			for j in self.feature_dict[i]:
				pass#print self.Gini(X,y,i,j)
		pass

	def Gini(self,X,y,k,k_v):
		if (self.property[k]==0):
			#print X[X[:,k]==k_v],'dasasdasdasd'
			#print X[:,k]!=k_v
			c1 = (X[X[:,k]==k_v]).shape[0]
			c2 = (X[X[:,k]!=k_v]).shape[0]
			D = y.shape[0]
			return c1*pGini(y[X[:,k]==k_v])/D+c2*pGini(y[X[:,k]!=k_v])/D
		else:
			c1 = (X[X[:,k]>=k_v]).shape[0]
			c2 = (X[X[:,k]<k_v]).shape[0]
			D = y.shape[0]
			#print c1,c2,D
			return c1*pGini(y[X[:,k]>=k_v])/D+c2*pGini(y[X[:,k]<k_v])/D
		pass
	def makeTree(self,X,y):
		min=10000.0
		m_i,m_j=0,0
		if (np.unique(y).size<=1):

			return (self.labels[y[0]])
		for i in range(self.X.shape[1]):
			for j in self.feature_dict[i]:
				p=self.Gini(X,y,i,j)
				if (p<min):
					min=p
					m_i,m_j=i,j
		
		

		if (min==1):
			return (y[0])
		left=[]
		righy=[]
		if (self.property[m_i]==0):
			left = self.makeTree(X[X[:,m_i]==m_j],y[X[:,m_i]==m_j])
			right = self.makeTree(X[X[:,m_i]!=m_j],y[X[:,m_i]!=m_j])
		else :
			left = self.makeTree(X[X[:,m_i]>=m_j],y[X[:,m_i]>=m_j])
			right = self.makeTree(X[X[:,m_i]<m_j],y[X[:,m_i]<m_j])
		return [(m_i,m_j),left,right]
	def train(self):
		self.DT=self.makeTree(self.X,self.y)
		print self.DT
		
	def pred(self,X):
		X=np.array(X)
		  
		result = np.zeros((X.shape[0],1))
		for i in range(X.shape[0]):
			tp=self.DT
			while ( type(tp) is  list):
				a,b=tp[0]
				
				if (self.property[a]==0):
					if (X[i][a]==b):
						tp=tp[1]
					else:
						tp=tp[2]
				else:
					if (X[i][a]>=b):
						tp=tp[1]
					else:
						tp=tp[2]
			result[i]=self.labels[tp]
		return result
		pass
	

               这个maketree让我想起了线段树………………代码里的变量基本都有说明

试验代码:

  

from __future__ import division
import numpy as np
import scipy as sp
from dml.DT import DTC
X=np.array([
[0,0,0,0,8],
[0,0,0,1,3.5],
[0,1,0,1,3.5],
[0,1,1,0,3.5],
[0,0,0,0,3.5],
[1,0,0,0,3.5],
[1,0,0,1,3.5],
[1,1,1,1,2],
[1,0,1,2,3.5],
[1,0,1,2,3.5],
[2,0,1,2,3.5],
[2,0,1,1,3.5],
[2,1,0,1,3.5],
[2,1,0,2,3.5],
[2,0,0,0,10],
])


y=np.array([
[1],
[0],
[1],
[1],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
])
prop=np.zeros((5,1))
prop[4]=1
a=DTC(X,y,prop)
a.train()
print a.pred([[0,0,0,0,3.0],[2,1,0,1,2]])

可以看到可以学习出一个决策树:


展示出来大概是这样:注意第四个参数是连续变量



               


四.reference

           【1】:《机器学习》 -mitchell,卡耐基梅龙大学
           【2】:《统计学习方法》-李航
nside君的MySQL网络培训班课程特点: 业界最权威的MySQL数据库培训师姜承尧老师(也就是Inside君本人啦)亲授.姜承尧老师出版了《MySQL技术内幕:InnoDB存储引擎》、《MySQL内核:InnoDB存储引擎》等Mysql书籍。 课程紧密结合互联网公司实践,学员能够领略到BAT、网易等大公司的数据库架构与应用案例 课纲结合最新的MySQL 5.6、5.7版本,使得学员学到的都是最新的内容 充分掌握课程内容的学员年薪至少在25W起,第1期的学员已经证明了培训的价值 优秀学员可以获得姜老师的BAT等大型互联网公司的内推 面试技巧与简历模板(新增),帮助学员拿到更好的offer MySQL 安装与引擎 day001-MySQL 5.7介绍和安装 day002-MySQL 5.7安装多实例 day003-MySQL升级 参数 连接 权限 day004-MySQL权限拾 遗Role模拟 Workbench 体系结构 day005-slow_log generic_log audit 存储引擎一 day006-存储引擎二 多实例安装上 day007-MySQL 多实例下 SSL MySQL 数据类型和SQL查询 开发 day008-MySQL 数据类型 day009-精通JSON类型 day010-Employees 临时表的创建 外键约束 day011-SQL语法之SELECT day012-子查询 INSERT UPDATE DELETE REPLACE day013-作业讲解一 Rank 视图 UNION 触发器上 day014-触发器下 存储过程 自定义函数 MySQL 执行计划与优化器 day015-索引 B+树 上 day016-索引 B+树 下 Explain 1 day017-Explain 2 MySQL innodb引擎优化 day018-磁盘 day019-磁盘测试 day020-InnoDB_1 表空间 General day021-InnoDB_2 SpaceID.PageNumber 压缩表) day022-InnoDB_3 透明表空间压缩 索引组织表 day023-InnoDB_4 页(2) 行记录 day024-InnoDB_5 – heap_number Buffer Poo day025-InnoDB_6 Buffer Pool与压缩页 CheckPoint LSN day026-InnoDB_7 doublewrite ChangeBuffer AHI FNP MySQL 索引与innodb锁机制 day027-Secondary Index day028-join算法锁_1 day029-锁_2 day030-锁_3 day031-锁_4 day032-锁_5 day032-锁5标清 day033-锁_6 事物_1 day033-锁_6 事物1标清 day034-事物_2 MySQL 性能衡量 day035-redo_binlog_xa day036-undo_sysbench day036-undosysbench标清 day037-tpcc_mysqlslap MySQL 备份与恢复 day038-purge死锁举例_MySQL backup备份_1 day039-MySQL backup备份恢复_2 MySQL 复制技术与高可用 day040-MySQL 备份恢复backup_3_replication_1 day041-backup_4-replication_2 day042-replication_3 day043-replication_4-GTID 1 day044-replication_5-GTID 2
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值