CART决策树python实现

CART决策树代码实现

调包实现(直接调用 sklearn 的决策树):

from sklearn import tree
import pydotplus

def cart_skl_test(csv_path="../dataSet/liquefaction_data_MLE.csv", png_path="cartTree.png"):
    """Fit scikit-learn's DecisionTreeClassifier on the liquefaction data and save a PNG of the tree.

    Args:
        csv_path: CSV file containing the feature columns 'CSR' and 'Vs' and the
            label column 'target'. Defaults to the original hard-coded path.
        png_path: file the Graphviz rendering of the fitted tree is written to.
            Defaults to the original hard-coded name.
    """
    # Import under a local alias so this function keeps working even if the
    # module-level name ``tree`` is rebound elsewhere (e.g. to a hand-built tree).
    from sklearn import tree as sk_tree

    df = pd.read_csv(csv_path)
    x = df[['CSR', 'Vs']]
    y = df['target']

    clf = sk_tree.DecisionTreeClassifier()  # default hyper-parameters
    clf.fit(x, y)
    # With out_file=None, export_graphviz returns the DOT source as a string.
    dot_data = sk_tree.export_graphviz(clf, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png(png_path)


详细代码:

  • 导入包
import pandas as pd
import math
  • 定义计算 $Gini$ 值(基尼不纯度)的函数
def get_gini(dataSet):
    """Return the Gini impurity ``1 - sum_k p_k**2`` of a list of samples.

    Each sample is an indexable row whose last element is its class label;
    ``p_k`` is the fraction of samples carrying label ``k``.  An empty list
    contributes no labels, so the result for it is 1.0.
    """
    total = len(dataSet)

    # Frequency of each label, in order of first appearance.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    squared_sum = 0.0
    for count in tally.values():
        share = float(count) / total
        squared_sum += math.pow(share, 2)
    return 1 - squared_sum
  • $dataSet$ 是数据集,$axis$ 是用于划分的特征所在列的下标,$value$ 是该特征上的划分阈值。

    该函数将数据集中第 $axis$ 个特征的值与 $value$ 比较:小于等于 $value$ 的样本划入左子集,其余样本划入右子集。

def splitDataSet(dataSet, axis, value):
    """Partition samples into two lists by a threshold on one column.

    Args:
        dataSet: iterable of indexable rows.
        axis: index of the column to compare.
        value: threshold; rows with ``row[axis] <= value`` go left.

    Returns:
        ``(left, right)`` -- two new lists that keep the input order; ``right``
        holds every row that did not satisfy the ``<=`` test.
    """
    left, right = [], []
    for row in dataSet:
        bucket = left if row[axis] <= value else right
        bucket.append(row)
    return left, right
  • 选择最好的属性及分割点,即加权基尼指数最小的分割点(第一层循环遍历属性,第二层循环遍历该属性的各个候选分割点)
def chooseBestFeatureToSplit(dataSet):
    """Find the feature/threshold whose binary split has the lowest weighted Gini.

    Every column except the last (the label) is examined.  Candidate
    thresholds for a column are the midpoints between consecutive distinct
    sorted values.  A candidate is scored as
    ``p_left * Gini(left) + (1 - p_left) * Gini(right)``.  Within a column and
    across columns, the first strictly smaller score wins.

    Side effect: prints a one-line summary.  The "Gini" value in that line is
    the impurity of the *unsplit* node, not the chosen split's score.

    Returns:
        ``(feature_index, threshold, weighted_gini)``.  When no column has two
        distinct values (so no candidate exists) the result is ``(-1, -1, 1.0)``.
    """
    n_rows = float(len(dataSet))
    n_cols = len(dataSet[0]) - 1  # trailing label column is not a feature
    node_gini = get_gini(dataSet)  # reported in the log line only

    win_col, win_cut, win_score = -1, -1, 1.0
    for col in range(n_cols):
        ordered = sorted({row[col] for row in dataSet})

        # Best threshold for this column on its own.
        col_cut, col_score = -1, 1.0
        for lower, upper in zip(ordered, ordered[1:]):
            cut = (lower + upper) / 2
            left, right = splitDataSet(dataSet, col, cut)
            w = len(left) / n_rows
            score = w * get_gini(left) + (1 - w) * get_gini(right)
            if score < col_score:
                col_cut, col_score = cut, score

        if col_score < win_score:
            win_col, win_cut, win_score = col, col_cut, col_score

    print("bestFeature: {}, bestSplitValue: {}, Gini: {}".format(win_col, win_cut, node_gini))

    return win_col, win_cut, win_score
  • 创建 $CART$ 决策树
def _majority_label(labels):
    """Return the most frequent label; ties go to the label seen first."""
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    # max() returns the first maximal key, i.e. the earliest-seen label on ties.
    return max(counts, key=counts.get)


def createTree(dataSet, paraFeatureName):
    """Recursively build a binary CART classification tree.

    Args:
        dataSet: non-empty list of samples; the last element of each sample is
            its class label and the preceding elements are numeric features.
        paraFeatureName: column names indexed by column position; the name of
            the chosen split column becomes the node's key.

    Returns:
        Either a class label (leaf) or a nested dict of the form
        ``{feature_name: {"<=t": left_subtree, ">t": right_subtree}}``.
    """
    classList = [example[-1] for example in dataSet]

    # All samples share one label -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    bestFeat, bestSplit, _ = chooseBestFeatureToSplit(dataSet)

    # -1 means no column has two distinct values (e.g. identical feature
    # vectors with different labels).  Without this check the label column
    # itself would be used as the split feature and an empty subset would
    # reach classList[0] above (IndexError); emit a majority-vote leaf instead.
    if bestFeat == -1:
        return _majority_label(classList)

    leftTree, rightTree = splitDataSet(dataSet, bestFeat, bestSplit)
    # Guard against a degenerate split (a float midpoint can round onto one of
    # its neighbouring values) so recursion never receives an empty subset.
    if not leftTree or not rightTree:
        return _majority_label(classList)

    bestFeatureName = paraFeatureName[bestFeat]
    myTree = {bestFeatureName: {}}  # nested dicts store the tree

    # Recurse on both sides of the threshold.
    myTree[bestFeatureName]["<=" + str(bestSplit)] = createTree(leftTree, paraFeatureName)
    myTree[bestFeatureName][">" + str(bestSplit)] = createTree(rightTree, paraFeatureName)
    return myTree
  • 代码执行
if __name__ == "__main__":
    # cart_skl_test()  # uncomment to also build the scikit-learn reference tree
    df = pd.read_csv("../dataSet/liquefaction_data_MLE.csv")  # read the .csv data
    featureName = df.columns.values  # column names; the last column is the label
    # One numpy row per sample: the list-of-rows layout createTree iterates over.
    dataSet = list(df.values)
    # Bind the result to ``cart_tree`` rather than ``tree`` so the imported
    # sklearn ``tree`` module is not shadowed at module level.
    cart_tree = createTree(dataSet, featureName)
    print(cart_tree)


结果展示

自写代码结果展示:
{'Vs':
    {'<=16.35':
        {'Vs':
            {'<=6.550000000000001':
                0.0,
            '>6.550000000000001':
                {'Vs':
                    {'<=11.2':
                        1.0,
                    '>11.2':
                        {'Vs':
                            {'<=15.55':
                                {'CSR':
                                    {'<=0.16999999999999998':
                                        0.0,
                                    '>0.16999999999999998':
                                        {'CSR':
                                            {'<=0.26':
                                                {'Vs':
                                                    {'<=11.45':
                                                        0.0,
                                                    '>11.45':
                                                        1.0}},
                                            '>0.26':
                                                {'Vs':
                                                    {'<=13.25':
                                                        1.0,
                                                    '>13.25':
                                                        0.0}}}}}},
                            '>15.55': 1.0}}}}}},
    '>16.35':
        {'CSR':
            {'<=0.29500000000000004':
                0.0,
            '>0.29500000000000004':
                {'CSR':
                    {'<=0.32999999999999996':
                        1.0,
                    '>0.32999999999999996':
                        0.0}}}}}}

调包结果展示:

在这里插入图片描述

  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值