基本的决策树笔记
决策树是个有监督机器学习,相比于其他机器学习算法,决策树的优点是:
- 可以处理非线性的数据,处理离散型数据能力强
- 可解释性强,没有θ
以sklearn.datasets中的鸢尾花数据集讲解这个决策树:
-
鸢尾花数据集介绍:花有四种特征:[sepal length(花萼长度), sepal width(花萼宽度), petal length(花瓣长度),petal width(花瓣宽度)]。 有三种类型的花分别是:[setosa, versicolor, virginica],也就是花有四种特征,根据这四种特征我们能判断花的种类,这里花的种类有三种。
-
决策树算法:说到机器学习,一般都会去构建损失函数,但是决策树是一种贪婪的学习方法,它会一直沿着当前的最优路径去分裂。如何选择最优路径呢?有两种比较常见的指标:gini, c4.5。下面用决策树训练鸢尾花数据集分类流程:
gini系数介绍:
gini系数说白了就是数据的纯度:gini系数公式:
所以,gini=0.663=1-[(29/90)**2+(34/90)**2+(27/90)**2],这是单个节点的gini系数计算方法。
上图的gini系数该怎么算呢?首先应该按照上面单个框的gini系数进行计算,然后这一层gini系数=29/(29+61)*0+61/(29+61)*0.493=0.159
则分裂指标=上一层gini系数-这一层gini系数=0.663-0.159=0.334,这个gini系数降的特别多,因为是贪婪的分叉,所以这次的gini系数下降的最多,分叉使用的花的特征也是对花进行分类的最佳特征。
c4.5介绍:
说到c4.5不得不提熵这个概念:熵就是数据的不确定性,数据越纯(种类越少),数据的熵就越小,与gini系数很相似。计算熵的公式:
举个例子:
29/90 = 0.32
34/90 = 0.38
27/90 = 0.3
上图的熵 = -(0.32log(0.32, 2)+0.38log(0.38, 2)+0.3*log(0.3, 2))=1.578, 对各节点熵的和与gini系数类似,都是乘对应的比例再相加。
c4.5(信息增益率)是有id3(信息增益)演变过来的。
id3(信息增益) = 上一层的熵-这一层的熵
c4.5 = id3/分类条件对应的熵
举个例子:在某个条件下,a分为了(29, 34,27),则分类条件对应的熵=1.578,
上面已经算过了
同样,在c4.5的分裂指标下,数据按照c4.5(信息增益率)最大的分裂条件进行分裂。
gini和c4.5的区别:gini是二叉分裂, c4.5是多叉分裂,当然还有很多其他的不同,就不一一细说了,详情可见https://www.cnblogs.com/TimVerion/p/11211749.html
下面给出决策树对鸢尾花数据集分类的完整代码:
import numpy as np
import pandas as pd
# 用pd.DataFrame()整理iris
from sklearn.datasets import load_iris
iris = load_iris()
data = iris['data']
data = pd.DataFrame(data, columns=iris['feature_names'])
data['target'] = iris['target']
# print(iris.keys())
# print(data)
# 有监督机器学习,提取X,y
X = data.iloc[:,0:4]
y = data.iloc[:,4]
# 分离训练数据和测试数据
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
# 训练模型,保存模型
from sklearn.tree import DecisionTreeClassifier
import pickle
dec_tree = DecisionTreeClassifier(criterion="gini")
dec_tree.fit(X_train, y_train)
file = open("dec_tree.pkl", "wb")
pickle.dump(dec_tree, file)
file.close()
# 画决策树图
# from sklearn.tree import export_graphviz
# import pickle
# file = open("dec_tree.pkl","rb")
# model = pickle.load(file)
# file.close()
# print(model)
# export_graphviz(
# model,
# out_file="iris_tree.dot",
# feature_names=iris.feature_names[:],
# class_names=iris.target_names,
# rounded=True,
# filled=True
# )
# $ dot -Tpng iris_tree.dot -o tree.png (PNG format)
# 测试模型
# from sklearn.metrics import accuracy_score
# import pickle
# file = open("dec_tree.pkl","rb")
# model = pickle.load(file)
# file.close()
# y_hat = model.predict(X_test)
# y_hat = y_hat.reshape(-1,1)
# kk = pd.DataFrame(y_hat, columns=["y_hat",])
# kk['y_true'] = np.array(y_test)
# print(kk)
# print(accuracy_score(y_test, y_hat))