【机器学习】02 决策树CART代码


1.引入库

import math
import operator

2.读入数据

def createDataset():
    dataSet = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']
    ]

    // 特征值列表
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']

    return dataSet, labels

3.找到样本最多的类

def majorityCnt(classList):
    classCount={}

    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1

    //降序
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    print(type(sortedClassCount))
    print(sortedClassCount)
    return sortedClassCount[0][0]

4.计算基尼值

def calcGini(dataSet):
    numEntries=len(dataSet)
    labelCounts={}

    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1

    for key in labelCounts:
        labelCounts[key]/=numEntries
        labelCounts[key]=labelCounts[key]*labelCounts[key]
        
    Gini=1-sum(labelCounts.values())
    return Gini

5.划分数据集

def splitDataSet(dataSet,axis,value): 
    retDataSet1=[]
	retDataSet2=[]
    for featVec in dataSet:
    	reducedFeatVec=[]
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet1.append(reducedFeatVec)
		else:
			reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet2.append(reducedFeatVec)
         
    return retDataSet1,retDataSet2

6.找出基尼指数最小的值

def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    if numFeatures==0:
    	return 0
    bestGini=1
    bestFeature=-1

    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)
        Gini={}

        for value in uniqueVals:
            subDataSet1,subDataSet2=splitDataSet(dataSet,i,value)
            prob1=len(subDataSet1)/float(len(dataSet))
            prob2=len(subDataSet2)/float(len(dataSet))
            subDataSet1Gini=calcGini(subDataSet1)
            subDataSet2Gini=calcGini(subDataSet2)
            Gini[value]=prob1*subDataSet1Gini+prob2*subDataSet2Gini
            
       		if Gini[value]<bestGini:
       			bestGini=Gini[value]
       			bestFeature=i
       			bestSplit=value

    return bestFeature,bestSplit

7.创建树

def createTree(dataSet,labels):
    //获得每一个标签
    classList=[example[-1] for example in dataSet]

    //标签全相同即全属于同一类别,返回该标签
    if classList.count(classList[0])==len(dataSet):
        return classList[0]
    //所有样本在所有属性上取值相同,类别标记为样本数最多的类
    if len(dataSet[0])==1:
        return majorityCnt(classList)

    //获取最优索引
    bestFeat,bestSplit=chooseBestFeatureToSplit(dataSet)
    //获取最优索引的名称
    bestFeatLabel=labels[bestFeat]

    //创建根节点
    myTree={bestFeatLabel:{}}
    //删除用过的结点
    del(labels[bestFeat])
    subLabels=labels[:]
    subDataSet1,subDataSet2=splitDataSet(dataSet,bestFeat,bestSplit)
    myTree[bestFeatLabel][bestSplit]=createTree(subDataSet1,subLabels)
    myTree[bestFeatLabel]['others']=createTree(subDataSet2,subLabels)

    return myTree

8.运行结果

dataSet,labels=createDataset()
myTree=createTree(dataSet,labels)
TreePlotter.createPlot(myTree)
print(myTree)

{‘纹理’: {‘清晰’: {‘触感’: {‘硬滑’: ‘好瓜’, ‘others’: {‘色泽’: {‘青绿’: {‘根蒂’: {‘稍蜷’: ‘好瓜’, ‘others’: ‘坏瓜’}}, ‘others’: ‘坏瓜’}}}}, ‘others’: {‘色泽’: {‘乌黑’: {‘敲击’: {‘沉闷’: ‘坏瓜’, ‘others’: ‘好瓜’}}, ‘others’: ‘坏瓜’}}}}

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值