学习打卡内容:
阅读《李航统计学习方法》中p55-p58页
总结决策树模型结构
理解决策树递归思想
阅读《李航统计学习》中p58-p63页
学习信息增益
学习信息增益率
阅读《李航统计学习》中p63-65页
学习ID3算法优缺点
学习C4.5算法优缺点
理解C4.5算法在ID3算法上有什么提升
学习C4.5算法在连续值上的处理
学习决策树如何生成
阅读《机器学习实战》中p37-p41页
划分数据集代码
选择最好的数据集划分方式代码
创建树的函数代码
要求:
- 参考机器学习实战,生成决策树
import pandas as pd
data=pd.read_csv('watermelon_3a.csv')
data.drop('Idx', axis=1, inplace=True)
信息增益方法
# 主要用了pandas方法来写
def calcShannonEnt(dataSet):
shannonEnt = 0.0
prob = dataSet['label'].value_counts()/len(dataSet)
shannonEnt -= (prob * np.log2(prob)).sum()
return shannonEnt
def splitDataSet(dataSet, feat, value):
data1 = dataSet[dataSet[feat]==value].copy()
data1.drop(feat, axis=1, inplace=True)
return data1
def splitcontinuousDataSet(dataSet, feat, value):
data1 = dataSet[dataSet[feat]<value].copy()
data1.drop(feat, axis=1, inplace=True)
data2 = dataSet[dataSet[feat]>value].copy()
data2.drop(feat, axis=1, inplace=True)
return data1, data2
def calccontinuousInfoGain(dataSet, feat):
'''
dataset_数据集
'''
baseEnt = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestMid=-1
ta = ((dataSet[feat].sort_values() + dataSet[feat].sort_values().shift(-1))/2).values[:-1]
for value in ta:
newEnt = 0.0
sublessDataSet, submoreDataSet = splitcontinuousDataSet(dataSet, feat, value)
probless = len(sublessDataSet) / float(len(dataSet))
probmore = len(submoreDataSet) / float(len(dataSet))
newEnt += (probless*calcShannonEnt(sublessDataSet) + probmore*calcShannonEnt(submoreDataSet))
infoGain = baseEnt - newEnt
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestMid = value
return bestInfoGain, bestMid
def chooseBestFeatureToSplit(dataSet):
'''
dataSet_数据集
'''
baseEnt = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = 'label'
for feat in dataSet.columns[:-1]:
newEnt = 0.0
if dataSet[feat].dtype == 'object':
for value in dataSet[feat].unique():
subDataSet = splitDataSet(dataSet, feat, value)
prob = len(subDataSet) / float(len(dataSet))
newEnt += prob*calcShannonEnt(subDataSet)
infoGain = baseEnt - newEnt
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = feat
else:
infoGain, bestMid = calccontinuousInfoGain(dataSet, feat)
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = feat
return bestFeature
def majorityCnt(dataSet):
return data['label'].value_counts().idxmax()
def createTree(dataSet):
if dataSet['label'].nunique() == 1:
return dataSet['label'].iloc[0]
if len(dataSet.columns) == 2:
return majorityCnt(dataSet)
bestFeat = chooseBestFeatureToSplit(dataSet)
myTree = {bestFeat:{}}
if dataSet[bestFeat].dtype == 'object':
for value in dataSet[bestFeat].unique():
myTree[bestFeat][value] = createTree(splitDataSet(dataSet, bestFeat, value))
else:
_, bestMid = calccontinuousInfoGain(dataSet, bestFeat)
sublessDataSet, submoreDataSet = splitcontinuousDataSet(dataSet, bestFeat, bestMid)
myTree[bestFeat]['<'+str(bestMid)] = createTree(sublessDataSet)
myTree[bestFeat]['>'+str(bestMid)] = createTree(submoreDataSet)
return myTree
createTree(data)
> {'texture': {'distinct': {'density': {'<0.38149999999999995': 0,
'>0.38149999999999995': 1}},
'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}},
'blur': 0}}
三.
3.1 ID3算法
ID3算法的核心是在决策树各个结点上应用信息增益准则选择特征,递归地构建决策树。具体方法是:从根节点(root node)开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点;再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或者没有特征可以选择为止。最后得到一个决策树。ID3相当于用极大似然法进行概率模型的选择
优点:
简单
缺点:
1)ID3算法采用信息增益来选择最优划分特征,然而人们发现,信息增益倾向与取值较多的特征,对于这种具有明显倾向性的属性,往往容易导致结果误差;
2)ID3算法没有考虑连续值,对与连续值的特征无法进行划分;
3)ID3算法无法处理有缺失值的数据;
4)ID3算法没有考虑过拟合的问题,而在决策树中,过拟合是很容易发生的;
5)ID3算法采用贪心算法,每次划分都是考虑局部最优化,而局部最优化并不是全局最优化,当然这一缺点也是决策树的缺点,获得最优决策树本身就是一个NP难题,所以只能采用局部最优;
3.2 C4.5算法
C4.5算法的提出旨在解决ID3算法的缺点 ,是对ID3算法的改进。
1)采用信息增益比来替代信息增益作为寻找最优划分特征,信息增益比的定义是信息增益和特征熵的比值,对于特征熵,特征的取值越多,特征熵就倾向于越大;
2)对于连续值的问题,将连续值离散化,在这里只作二类划分,即将连续值划分到两个区间,划分点取两个临近值的均值,因此对于m个连续值总共有m-1各划分点,对于每个划分点,依次算它们的信息增益,选取信息增益最大的点作为离散划分点;
3)对于缺失值的问题,我们需要解决两个问题,第一是在有缺失值的情况下如何选择划分的属性,也就是如何得到一个合适的信息增益比;第二是选定了划分属性,对于在该属性的缺失特征的样本该如何处理。
4)对于过拟合的问题,采用了后剪枝算法和交叉验证对决策树进行剪枝处理,这个在CART算法中一起介绍。
还存在的不足:
1)C4.5的剪枝算法不够优秀;
2)C4.5和ID3一样,都是生成的多叉树,然而在计算机中二叉树模型会比多叉树的运算效率高,采用二叉树也许效果会更好;
3)在计算信息熵时会涉及到大量的对数运算,如果是连续值还需要进行排序,寻找最优离散划分点,这些都会增大模型的运算;
4)C4.5算法只能处理分类问题,不能处理回归问题,限制了其应用范围。