数据选取和数据情况
本次实验选取鸢尾花数据集(http://archive.ics.uci.edu/ml/datasets/Iris)
数据包含5列,分别是花萼长度、花萼宽度、花瓣长度、花瓣宽度、鸢尾花种类。
鸢尾花属种类包含三种:iris-setosa, iris-versicolour, iris-virginica。
每一类分别是50条,共150条数据;每一类在四个属性的分布情况如下图所示
可视化代码
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('seaborn')
import seaborn as sns
data = pd.read_table("iris.txt",sep=',',header=None)
data.columns=['sepal_len','sepal_wid','petal_len','petal_wid','label']
# 可视化
sns.set_style("whitegrid")
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864']
# 绘制 Violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True)
sns.violinplot(x='label', y='sepal_len', data=data, palette=antV, ax=axes[0, 0])
sns.violinplot(x='label', y='sepal_wid', data=data, palette=antV, ax=axes[0, 1])
sns.violinplot(x='label', y='petal_len', data=data, palette=antV, ax=axes[1, 0])
sns.violinplot(x='label', y='petal_wid', data=data, palette=antV, ax=axes[1, 1])
plt.savefig('analysis.png',format='png', dpi=300)
plt.show()
利用C4.5算法分类
离散化连续变量
因为花萼长度、花萼宽度、花瓣长度、花瓣宽度均为连续变量,所以需要进行离散化处理;这里通过Gini Index来进行离散化处理,考虑到此次分三类,且通过上面的可视化,三种花在4个属性上分布均存在较大差异,所以对花萼长度、花萼宽度、花瓣长度、花瓣宽度四个属性均采用两个分界点来分成三类。
实现计算Gini Index寻找最优分界点的python代码见附件select_bestpoints.py,计算得到:
花萼长度的最优分界点为5.4,6.1,此时Gini Index=0.3793
花萼宽度的最优分界点为2.9,3.3,此时Gini Index=0.5097
花瓣长度的最优分界点为1.9,4.7,此时Gini Index= 0.0843
花瓣宽度的最优分界点为0.8,1.7,此时Gini Index= 0.0735
计算得到的分界结果与每一类在四个属性的分布图也比较吻合
根据计算得到的最优分界点,在鸢尾花数据中应用分类:
data = pd.read_table(datafile,sep=',',header=None)
data.columns=['sepal_len','sepal_wid','petal_len','petal_wid','label']
def trans(x,n1,n2):
if x<=n1:
s=1
elif (x>n1)&(x<=n2):
s=2
else:
s=3
return s
data['sepal_len'] = data['sepal_len'].apply(lambda x : trans(x,5.4,6.1))
data['sepal_wid'] = data['sepal_wid'].apply(lambda x : trans(x,2.9,3.3))
data['petal_len'] = data['petal_len'].apply(lambda x : trans(x,1.9,4.7))
data['petal_wid'] = data['petal_wid'].apply(lambda x : trans(x,0.8,1.7))
属性 | 1 | 2 | 3 |
---|---|---|---|
花萼长度 | x<=5.4 | 5.4<x<=6.1 | x>6.1 |
花萼宽度 | x<=2.9 | 2.9<x<=3.3 | x>3.3 |
花瓣长度 | x<=1.9 | 1.9<x<=4.7 | x>4.7 |
花瓣宽度 | x<=0.8 | 0.8<x<=1.7 | x>1.7 |
C4.5原理
ID3算法的核心是在决策树各个结点上对应信息增益准则选择特征,递归地构建决策树。具体方法是:从根结点(root node)开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子节点;再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或没有特征可以选择为止。最后得到一个决策树。
但是ID3采用信息增益来选择特征,存在一个缺点,它一般会优先选择有较多属性值的特征,因为属性值多的特征会有相对较大的信息增益。为了避免ID3的不足,C4.5中是用**信息增益率(gain ratio)**来作为选择分支的准则。对于有较多属性值的特征,信息增益率的分母Split information(S,A),我们称之为分裂信息,会稀释掉它对特征选择的影响。
即ID3计算信息增益(GAIN)来选择特征,而C4.5中是用信息增益率(GainRATIO)
C4.5实现
递归创建决策树时,递归有两个终止条件:第一个停止条件是该处所有数据标签完全相同,则直接返回该类标签;第二个停止条件是使用完了所有特征,仍然不能将数据划分仅包含唯一类别的分组,即决策树构建失败,特征不够用。由于第二个停止条件无法简单地返回唯一的类标签,这里挑选出现数量最多的类别作为返回值。具体实现代码见附件C4.5.py
def createTree(data,labels,featLabels):
'''
建立决策树
'''
classList = [rowdata[-1] for rowdata in data] # 取每一行的最后一列,分类结果(1/0)
if classList.count(classList[0])==len(classList):
return classList[0]
if len(data[0])==1 or len(labels)==0:
return majorityCnt(classList)
bestFeat = BestSplit(data) #根据信息增益选择最优特征
if bestFeat==-1:
return majorityCnt(classList)
bestLab = labels[bestFeat]
featLabels.append(bestLab)
myTree = {bestLab:{}} #分类结果以字典形式保存
del(labels[bestFeat])
featValues = [rowdata[bestFeat] for rowdata in data]
uniqueVals =set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestLab][value] = createTree(splitData(data,bestFeat,value),subLabels,featLabels)
return myTree
运行如下图所示:
进一步将分类结果可视化如下图所示:
(使用可视化时需要把附件中的treePlotter.py放置python库路径下)
上图分支中的1,2,3分别对应于离散化处理时,花萼长度、花萼宽度、花瓣长度、花瓣宽度对应的分割区间,即:
属性 | 1 | 2 | 3 |
---|---|---|---|
花萼长度 | x<=5.4 | 5.4<x<=6.1 | x>6.1 |
花萼宽度 | x<=2.9 | 2.9<x<=3.3 | x>3.3 |
花瓣长度 | x<=1.9 | 1.9<x<=4.7 | x>4.7 |
花瓣宽度 | x<=0.8 | 0.8<x<=1.7 | x>1.7 |
同时注意到我们的决策树是略有问题的,下图划线处均为同一种,却有分支。这里分析是因为鸢尾花数据集本来是连续性数据,这里强行离散化处理,并不能很好的进行区分,所以在这些分支,**每一个分支下对应的数据均没有把数据完全分开,即未达到递归的第一个终止条件,达到了第二个条件,而且在每个分支中最大的都是同一种标签,就出现了这种情况。**后期需要针对这种情况,进一步改进优化。
拿构建的简单数据测试编写的算法如下图,发现是正常的,验证了上述想法
为了方便使用决策树进行分类,可以定义一个 保存决策树的函数
def storeTree(inputTree, filename):
'''
保存决策树
:param inputTree:
:param filename:
:return:
'''
with open(filename, 'wb') as fw:
pickle.dump(inputTree, fw)
应用训练好的决策树分类
首先定义一个读取决策树的函数
def grabTree(tree_file,fea_file):
'''
读取决策树和特征属性
:param filename:
:return:
'''
mytree=json.load(open(tree_file,'r+'))
featLabels=json.load(open(fea_file,'r+'))
return mytree,featLabels
前面训练时有个featLabels参数。它是用来干什么的?它就是用来记录各个分类结点的,在用决策树做预测的时候,我们按顺序输入需要的分类结点的属性值即可。因为是决策树保存的形式是嵌套字典,可以按特征属性顺序依次取出,得到分类标签。应用决策树的完整代码见附件apply_C4.5.py。
for i in range(0,len(featLabels)):
try:
myTree=myTree[featLabels[i]][testvalue[i]]
except:
print("这个测试结果是:")
print(myTree)
break
运行结果如下图所示
学习链接:https://cuijiahua.com/blog/2017/11/ml_3_decision_tree_2.html
本文全部代码可以关注下方博主公众号获取,后台回复:鸢尾花