决策树笔记一

基本的决策树笔记

决策树是个有监督机器学习,相比于其他机器学习算法,决策树的优点是:

  1. 可以处理非线性的数据,处理离散型数据能力强
  2. 可解释性强,没有θ

以sklearn.datasets中的鸢尾花数据集讲解这个决策树:

  1. 鸢尾花数据集介绍:花有四种特征:[sepal length(花萼长度), sepal width(花萼宽度), petal length(花瓣长度),petal width(花瓣宽度)]。 有三种类型的花分别是:[setosa, versicolor, virginica],也就是花有四种特征,根据这四种特征我们能判断花的种类,这里花的种类有三种。

  2. 决策树算法:说到机器学习,一般都会去构建损失函数,但是决策树是一种贪婪的学习方法,它会一直沿着当前的最优路径去分裂。如何选择最优路径呢?有两种比较常见的指标:gini, c4.5。下面用决策树训练鸢尾花数据集分类流程:

在这里插入图片描述
gini系数介绍:
在这里插入图片描述
gini系数说白了就是数据的纯度:gini系数公式:在这里插入图片描述
所以,gini=0.663=1-[(29/90)**2+(34/90)**2+(27/90)**2],这是单个节点的gini系数计算方法。
在这里插入图片描述
上图的gini系数该怎么算呢?首先应该按照上面单个框的gini系数进行计算,然后这一层gini系数=29/(29+61)*0+61/(29+61)*0.493=0.159
则分裂指标=上一层gini系数-这一层gini系数=0.663-0.159=0.334,这个gini系数降的特别多,因为是贪婪的分叉,所以这次的gini系数下降的最多,分叉使用的花的特征也是对花进行分类的最佳特征。

c4.5介绍:
说到c4.5不得不提熵这个概念:熵就是数据的不确定性,数据越纯(种类越少),数据的熵就越小,与gini系数很相似。计算熵的公式:
在这里插入图片描述
举个例子:
在这里插入图片描述
29/90 = 0.32
34/90 = 0.38
27/90 = 0.3
上图的熵 = -(0.32log(0.32, 2)+0.38log(0.38, 2)+0.3*log(0.3, 2))=1.578, 对各节点熵的和与gini系数类似,都是乘对应的比例再相加。
c4.5(信息增益率)是有id3(信息增益)演变过来的。

id3(信息增益) = 上一层的熵-这一层的熵


c4.5 = id3/分类条件对应的熵
举个例子:在某个条件下,a分为了(29, 34,27),则分类条件对应的熵=1.578,
上面已经算过了

同样,在c4.5的分裂指标下,数据按照c4.5(信息增益率)最大的分裂条件进行分裂。

gini和c4.5的区别:gini是二叉分裂, c4.5是多叉分裂,当然还有很多其他的不同,就不一一细说了,详情可见https://www.cnblogs.com/TimVerion/p/11211749.html

下面给出决策树对鸢尾花数据集分类的完整代码:

import numpy as np
import pandas as pd

# 用pd.DataFrame()整理iris
from sklearn.datasets import load_iris
iris = load_iris()
data = iris['data']
data = pd.DataFrame(data, columns=iris['feature_names'])
data['target'] = iris['target']
# print(iris.keys())
# print(data)

# 有监督机器学习,提取X,y
X = data.iloc[:,0:4]
y = data.iloc[:,4]

# 分离训练数据和测试数据
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

# 训练模型,保存模型
from sklearn.tree import DecisionTreeClassifier
import pickle
dec_tree = DecisionTreeClassifier(criterion="gini")
dec_tree.fit(X_train, y_train)
file = open("dec_tree.pkl", "wb")
pickle.dump(dec_tree, file)
file.close()

# 画决策树图
# from sklearn.tree import export_graphviz
# import pickle
# file = open("dec_tree.pkl","rb")
# model = pickle.load(file)
# file.close()
# print(model)
# export_graphviz(
#     model,
#     out_file="iris_tree.dot",
#     feature_names=iris.feature_names[:],
#     class_names=iris.target_names,
#     rounded=True,
#     filled=True
# )
# $ dot -Tpng iris_tree.dot -o tree.png    (PNG format)


# 测试模型
# from sklearn.metrics import accuracy_score
# import pickle
# file = open("dec_tree.pkl","rb")
# model = pickle.load(file)
# file.close()
# y_hat = model.predict(X_test)
# y_hat = y_hat.reshape(-1,1)
# kk = pd.DataFrame(y_hat, columns=["y_hat",])
# kk['y_true'] = np.array(y_test)
# print(kk)
# print(accuracy_score(y_test, y_hat))



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值