[Machine Learning] CART Decision Tree Source Code (Python)

A detailed introduction to decision trees (C4.5, ID3, CART) and the underlying formulas is in an earlier post: Decision Trees Explained.
Personally, I also find that Dr. Hang Li's <Statistical Learning Methods> explains these machine learning algorithms very well.
More complete code can be found on Zhihu.

This post focuses on the implementation details of a CART decision tree; the code is commented in some detail.
The dataset has four features plus a label as the last column (SepalLength, SepalWidth, PetalLength, PetalWidth, Name), 150 samples in total with 50 per class; Name takes one of three values (setosa, versicolor, virginica). A preview of the dataset:
[Figure: preview of the iris dataset]
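In case the image does not render: the file (`fishiris.csv` in the code below) is assumed to be a plain comma-separated file with a header row; the first data rows of the standard iris dataset look like this:

```
SepalLength,SepalWidth,PetalLength,PetalWidth,Name
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
```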
First, a picture of the final decision tree (the code does not implement this visualization).
[Figure: the final decision tree]
The full code is as follows:

# -*- coding: utf-8 -*-

import numpy as np


class Tree:
    def __init__(self, value=None, trueBranch=None, falseBranch=None, results=None, col=-1, data=None):
        self.value = value
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch
        self.results = results
        self.col = col
        self.data = data

    def __str__(self):
        return "col: {}, value: {}, results: {}".format(self.col, self.value, self.results)


def calculateDiffCount(datas):
    """
    Helper for the gini computation: count how many samples belong to each class.
    E.g. if the labels in the dataset are ['A', 'B', 'C', 'A', 'A', 'D'],
    the counts are {'A': 3, 'B': 1, 'C': 1, 'D': 1}.
    """
    # Count the samples in each class; returns {class1: count, class2: count, ...}
    # This is the C_k term in the Gini formula.
    results = {}
    for data in datas:
        # data[-1] is the label (y value) of the row
        results[data[-1]] = results.get(data[-1], 0) + 1
    return results


def gini(rows):
    # Compute the Gini impurity: 1 - sum_k (|C_k| / |D|)^2
    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += (results[i] / length) ** 2
    return 1 - imp


def splitDatas(rows, value, column):
    # Split the dataset on one column; returns two lists:
    # rows that satisfy the condition and rows that do not.
    # rows is the dataset (D), column is the feature index (A), value is the split point.
    list1 = []
    list2 = []

    if isinstance(value, int) or isinstance(value, float):
        for row in rows:
            if row[column] >= value:
                list1.append(row)
            else:
                list2.append(row)
    else:
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)
    return list1, list2


def buildDecisionTree(rows, evaluationFunction=gini):
    # Build the decision tree recursively; stop recursing when the best gain is 0.
    # Returns the root Tree node.
    currentGain = evaluationFunction(rows)
    column_length = len(rows[0])
    rows_length = len(rows)

    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the split with the best gain
    for col in range(column_length - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)

    # recurse further, or create a leaf when no split reduces the impurity
    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value=best_value[1], trueBranch=trueBranch, falseBranch=falseBranch)
    else:
        return Tree(results=calculateDiffCount(rows), data=rows)


def classify(data, tree):
    if tree.results is not None:
        return tree.results
    else:
        branch = None
        v = data[tree.col]
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.trueBranch
            else:
                branch = tree.falseBranch
        else:
            if v == tree.value:
                branch = tree.trueBranch
            else:
                branch = tree.falseBranch
        return classify(data, branch)


def loadCSV():
    def convertTypes(s):
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s
    # Load the CSV as strings, drop the header row, then convert each field.
    data = np.loadtxt("fishiris.csv", dtype='str', delimiter=',')
    data = data[1:, :]
    dataSet = [[convertTypes(item) for item in row] for row in data]
    return dataSet


if __name__ == '__main__':
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    test_data = [4.9, 3, 1.4, 0.2]
    r = classify(test_data, decisionTree)
    print(r)
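As a quick sanity check on the split criterion, the Gini impurity can be verified by hand on a toy label set. This standalone snippet (a sketch, independent of the code above) re-implements the same 1 - sum(p_k^2) formula:

```python
# Standalone sanity check of the Gini impurity used above.
# For labels ['A', 'A', 'B']: gini = 1 - (2/3)**2 - (1/3)**2 = 4/9
from collections import Counter

def gini_impurity(labels):
    # 1 - sum over classes of (count_k / n)^2
    n = len(labels)
    counts = Counter(labels)
    return 1 - sum((c / n) ** 2 for c in counts.values())

print(gini_impurity(['A', 'A', 'B']))  # 4/9, about 0.444
print(gini_impurity(['A', 'A', 'A']))  # pure node -> 0.0
```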

The prediction result:
[Figure: prediction output]
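The tree figure earlier in the post was drawn separately; the code here does not generate it. A minimal text-based printer could look like the following sketch (the small `Tree` stub just mirrors the fields produced by `buildDecisionTree`, and the demo split values are hypothetical):

```python
# Sketch: print a trained tree as indented text.
class Tree:
    def __init__(self, value=None, trueBranch=None, falseBranch=None, results=None, col=-1):
        self.value = value
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch
        self.results = results
        self.col = col

def printTree(tree, indent=''):
    if tree.results is not None:
        # Leaf node: show the class counts stored at this node.
        print(indent + str(tree.results))
    else:
        # Internal node: show the split condition, then both branches.
        print(indent + 'col %d >= %s?' % (tree.col, tree.value))
        print(indent + '-> True:')
        printTree(tree.trueBranch, indent + '    ')
        print(indent + '-> False:')
        printTree(tree.falseBranch, indent + '    ')

# Hand-built two-leaf example (hypothetical split on PetalLength):
demo = Tree(col=2, value=2.45,
            trueBranch=Tree(results={'versicolor': 50, 'virginica': 50}),
            falseBranch=Tree(results={'setosa': 50}))
printTree(demo)
```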
