Python 手写实现鸢尾花决策树

​Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文中使用了它 (直至今日该论文仍然被频繁引用)。

在这里插入图片描述

Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal width),可通过4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。在三个类别中,其中有一个类别和其他两个类别是线性可分的。

在sklearn中已内置了此数据集。

核心算法思想就是:

每到一个节点,对于现在数据集中四个属性的所有值,根据小于和大于此属性值分成两个数据集,并计算香农熵,取所有香农熵最大增益的那个特征及值作为划分标准;划分直到所有数据均为一类数据。

为了说明这一思想,举个例子:

现在数据集中有三个数据:

1:0.1 0.2 0.3 0.4 0

2:0.2 0.3 0.4 0.5 1

3:0.3 0.4 0.4 0.6 2

此时,数据集不是同一类数据,所以要根据某个标准进行划分

为了找到那个标准,先考察pos=0,即第一个特征。将第一个特征的三个值0.1,0.2,0.3分别作为标准,比如:

将0.1作为标准,pos=0小于等于0.1的数据,即数据1划分为一个数据集,大于0.1的数据集,即数据2,3划分为一个数据集

则得到两个数据集:{1},{2,3}

计算此时的香农熵,接下来计算第一个特征的其余值作为标准时的香农熵,之后计算出第2,3,4个属性所有值的香农熵,这些香农熵中最大的那个pos和值即为我们决策树的节点

下面是这一逻辑的代码:

#选择最好的特征值进行分类
def choose_best_split(data_set):
    base_Ent=calculate_Ent(data_set)
    best_increase=0.0
    best_feature=[-1,-1]
    for i in range(4):
        features=[j[i] for j in data_set]
        unique=set(features)
        for feature in unique:
            less_Set,more_Set=spliit_Set(data_set, i, feature)
            tmp=len(less_Set)/float(len(data_set))
            new_Ent=tmp*calculate_Ent(less_Set)
            new_Ent+=(1-tmp)*calculate_Ent(more_Set)
            increase=base_Ent-new_Ent
            if increase>best_increase:
                best_increase=increase
                best_feature=[i,feature]
    return best_feature,best_increase

一、数据集初始化

将标签附到特征值之后:

#初始化数据集
def init_data_set():
    iris = load_iris()  #导入数据集iris
    iris_feature = iris.data.tolist()    #特征数据
    iris_target = iris.target.tolist()   #分类数据
    for i in range(len(iris_feature)):
        iris_feature[i].append(iris_target[i])
    return iris_feature

二、划分数据集

将数据集划分成训练集和测试集

#划分数据集
def create_set(data_set,split_rate=0.8):
    #0的是测试集,1的是训练集
    length=len(data_set)
    train_num=int(length*split_rate)
    test_num=length-train_num
    random_list=[1]*train_num
    random_list.extend([0]*test_num)
    random.shuffle(random_list)
    test_set=[]
    train_set=[]
    for i in range<
  • 13
    点赞
  • 91
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值