一、决策树
1.概念
决策树在现实生活中应用广泛,也非常容易理解,通过构建一颗决策树,只要根据树的的判断条件不断地进行下去,最终就会返回一个结果。例如下图所示。决策树天然地可以解决多分类问题,同时也可以应用于回归问题中。
现在先通过sklearn中封装的决策树方法对数据进行分类,来学习决策树。
- import numpy as np
- import matplotlib.pyplot as plt
- from sklearn import datasets
- iris = datasets.load_iris()
- x=iris.data[:,2:]#这里就用了2个特征
- y=iris.target
- plt.scatter(x[y==0,0],x[y==0,1])
- plt.scatter(x[y==1,0],x[y==1,1])
- plt.scatter(x[y==2,0],x[y==2,1])
- plt.show()
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
iris = datasets.load_iris()
x=iris.data[:,2:]#这里就用了2个特征
y=iris.target
plt.scatter(x[y==0,0],x[y==0,1])
plt.scatter(x[y==1,0],x[y==1,1])
plt.scatter(x[y==2,0],x[y==2,1])
plt.show()
- from sklearn.tree import DecisionTreeClassifier
- dt = DecisionTreeClassifier(max_depth=2,criterion=“entropy”) #决策树深度2
- dt.fit(x,y)
- def plot_decision_boundary(model,axis):
- x0,x1 = np.meshgrid(
- np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
- np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
- )
- x_new = np.c_[x0.ravel(),x1.ravel()]
- y_predict = model.predict(x_new)
- zz = y_predict.reshape(x0.shape)
- from matplotlib.colors import ListedColormap
- custom_cmap = ListedColormap([‘#EF9A9A’,’#FFF59D’,’#90CAF9’])
- plt.contourf(x0,x1,zz,linewidth =5,cmap=custom_cmap)
- plot_decision_boundary(dt,axis=[0.5,7.5,0,3])
- plt.scatter(x[y==0,0],x[y==0,1])
- plt.scatter(x[y==1,0],x[y==1,1])
- plt.scatter(x[y==2,0],x[y==2,1])
- plt.show()
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(max_depth=2,criterion="entropy") #决策树深度2
dt.fit(x,y)
def plot_decision_boundary(model,axis):
x0,x1 = np.meshgrid(
np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
)
x_new = np.c_[x0.ravel(),x1.ravel()]
y_predict = model.predict(x_new)
zz = y_predict.reshape(x0.shape)
from matplotlib.colors import ListedColormap
custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
plt.contourf(x0,x1,zz,linewidth =5,cmap=custom_cmap)
plot_decision_boundary(dt,axis=[0.5,7.5,0,3])
plt.scatter(x[y==0,0],x[y==0,1])
plt.scatter(x[y==1,0],x[y==1,1])
plt.scatter(x[y==2,0],x[y==2,1])
plt.show()
得出来的决策边界可以绘制出如下右图所示的决策树,当数据小于2.4时就分为A,如果大于2.4就继续考察,当这部分的样本y小于1.8时就分为B,如果大于1.8就分为C。这个决策树总共的深度为2,即有2层判断条件。但是这个决策依据是如何得出来的?又该在哪个维度哪个值进行划分?下面介绍一个重要的概念-信息熵。
2.信息熵
熵在信息论中代表随机变量不确定度的度量,由香农提出来的。熵越大,数据的不确定性越高,数据越混乱;熵越小,数据的不确定性越低,数据越趋向于集中统一。计算公式如下,k代表类别,pi代表这个类别所对应的概率是多少。
对于一个数据集,假设其各个类别所对应的比重都为1/3,则H=-1/3log(1/3)-1/3log(1/3)-1/3log(1/3)=1.0986。
对于一个数据集,假设其各个类别所对应的比重分别为{1/10,2/10,7/10},则H=-1/10log(1/10)- 2/10log(2/10)- 7/10log(7/10) =0.8018。
这个时候,我们可以说第二个数据集比第一个数据集更确定。容易解释,在第二个数据集中,大部分的数据都能够确定在7/10所对应的数据中,因此可以说更加的确定,混乱度越低。更极端的,如果一个数据集对应的比重为{1,0,0},那么H=0,数据的混乱度为0,能够直接确定数据在第一个类别中。
决策树的划分依据:选取某个特征,样本经过该特征进行分类后,得到的几个子集有最低的信息熵,它相对于原来的数据集就有最大的信息增益(信息增益=原来样本信息熵-划分后样本的信息熵)。接着在得到的子集上再通过信息熵来进一步选择某特征,使得继续划分得到的子集有最低的信息熵。以此类推,直到树的深度满足要求,或者当前树的的节点数据已经能够完全确定了。
举例如图所示,对于判断一个人是否会购买物品的的决策树,原始数据集有年龄、信誉和是否为学生这3个特征,假设原始的数据集对应的信息熵为H0。现在用年龄划分为{青年,中年,老年}与用其他2个特征划分得到的结果相比,拥有最低的信息熵,即信息增益最大,那么第一个划分的依据就选择年龄。接下来,划分得到的子数据集继续考察除了年龄外的其他特征,对于青年,发现用是否为学生这个特征能够得到最低的信息熵,因此采用是否为学生这个特征;对于中年,发现样本都确定为某一类,所以就不用继续划分;对于老年,发现用信誉这个特征能够得到最低的信息熵,因此采用信誉这个特征。这样就得到了最终的决策树。
3.ID3算法代码实现
为了模拟,我们先创建一个数据集,该数据集有2个特征,分别代表是否可以浮出水面、是否有脚蹼。用1代表是,0代表否。最后一列数据表示当前样本是否为鱼类。【该代码和例子取自机器学习实战】
- def createDataSet():
- dataSet = [[1, 1, ‘yes’], #例如这个样本点代表不能浮出水面、有脚蹼,是鱼类
- [1, 1, ‘yes’],
- [1, 0, ‘no’],
- [0, 1, ‘no’],
- [0, 1, ‘no’]]
- labels = [‘no surfacing’,’flippers’] #label记录样本的特征名称
- return dataSet, labels
def createDataSet():
dataSet = [[1, 1, 'yes'], #例如这个样本点代表不能浮出水面、有脚蹼,是鱼类
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers'] #label记录样本的特征名称
return dataSet, labels
输入数据集,计算该数据集所对应的信息熵的值。
- from math import log
- def calcShannonEnt(dataSet): #输入数据集,计算信息熵
- numEntries = len(dataSet) #计算有多少个样本
- labelCounts = {} #创建一个字典,用于保存样本的标签,以及该标签对应的数量
- for featVec in dataSet: #遍历所有样本
- currentLabel = featVec[-1] #将每个样本的最后一列,即标签取出
- if currentLabel not in labelCounts.keys(): #如果该标签不在字典中,就加入该标签,并且将数目置为0