目录
引入:举个简单的例子,比方买个水杯,首先看保温效果如何,选个保温效果好的后看水杯价格有没有超预算,在没超预算的情况下看好不好看,颜值高果断拿下。这样的思维方式实质上就是在构建决策树,通过保温效果、预算、颜值这几个特征选出是否购买(ps这个属于二叉树的例子了)。
一、概念
1、概念
决策树分二叉树和非二叉树,通过分析某些特征对数据集进行划分,从而预测新的待测样本。简单来说就是想象一棵树,根节点代表我们对数据集特征选取,而树枝和叶子代表在特征属性下对数据集的不同判断和结果。
2、结束条件
(1)没有特征可以往下分了
(2)剩下的样本都属于同一个类别
(3)剩下的样本数少于设定的阈值
(4)达到设定的深度
3、熵
那么决策树如何构建?怎么决定哪个特征先判断?——由此,引出熵这个概念,即表示不确定性、混乱程度。熵越大越混乱。
熵的表达式为:
后续计算信息增益/增益率需要用到信息熵。
二、算法&具体例题计算
1、ID3信息增益
1)计算集合D的信息熵H(D):
2)计算特征A条件下D的信息熵H(D|A),即A的影响程度大不大:
3)最后算差值计算信息增益:Gain(D,A) = H(D) - H(A)
2、C4.5 信息增益率
其实就是上面的信息增益 / 特征A的信息熵:
3、基尼指数
基尼指数越小纯度越高。
1)计算整个样本D的基尼指数:
2)计算在特征A的条件下集合D的基尼指数,即A用于判断集合D的影响:
3)最后选择基尼指数较小的特征进行划分。
二-一 例题
下面举个具体的例子来说明这三个算法的决策树构建:
样本 | 属性 | 分类 | |
x1 | x2 | ||
1 | T | T | √ |
2 | T | F | √ |
3 | T | F | × |
4 | F | T | √ |
5 | F | T | × |
6 | F | T | √ |
已知:
分别用不同算法构造决策树
(1)ID3信息增益:
H(D) = -(4/6)log(4/6) - (2/6)log(2/6) =0.918
H(T1) = -(2/3)log(2/3) - (1/3)log(1/3) = 0.918
H(F1) = -(2/3)log(2/3) - (1/3)log(1/3) = 0.918
Gain(D,x1) = H(D) - [ (1/2)H(T1) + (1/2)H(F1) ] = 0
H(T2) = -(3/4)log(3/4)-(1/4)log(1/4) = 0.811
H(F2) = -(1/2)log(1/2) - (1/2)log(1/2) = 1
Gain(D,x2) =H(D) -[ (4/6)H(T2) + (2/6)H(F2) ] = 0.918-0.874 = 0.044
0.044>0,选x2为第一个特征。
(2)Gini
0.444>0.417,基尼指数越小纯度越高所以选X2作为划分特征。
三、python实例
1、基尼指数
(1)构建gini代码计算acc
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from sklearn import tree
def load_date():
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ),(0.3081, ))])
dataset_train = datasets.MNIST(root='../data/minist',train=True,download=True,transform=transform)
dataset_test = datasets.MNIST(root='../data/minist',train=False,download=True,transform=transform)
X_train = dataset_train.data.numpy()
X_test = dataset_test.data.numpy()
X_train = np.reshape(X_train,(60000,784))
X_test =np.reshape(X_test,(10000,784))
Y_train = dataset_train.targets.numpy()
Y_test = dataset_test.targets.numpy()
return X_train,Y_train,X_test,Y_test
if __name__ == '__main__':
train_x,train_y,test_x,test_y = load_date()
cart = tree.DecisionTreeClassifier(criterion='gini',max_depth=8,random_state=5)
# cart = tree.DecisionTreeClassifier(criterion='entropy',max_depth=8)
cart = cart.fit(train_x,train_y)
acc = cart.score(test_x,test_y)
print("准确率:",acc)
(2)决策树可视化
#可视化
plt.figure(figsize=(12.8, 6.4))
plt.subplot(121)
tree.plot_tree(cart)
plt.title("Gini")
plt.show()
2、ID3信息增益
(1)构建ID3代码计算acc
if __name__ == '__main__':
train_x,train_y,test_x,test_y = load_data()
# cart = tree.DecisionTreeClassifier(criterion='gini',max_depth=3,random_state=5)
cart = tree.DecisionTreeClassifier(criterion='entropy',max_depth=8)
cart = cart.fit(train_x,train_y)
acc = cart.score(test_x,test_y)
print("准确率:",acc)
(2)决策树可视化
#可视化
plt.figure(figsize=(12.8, 6.4))
plt.subplot(121)
tree.plot_tree(cart)
plt.title("ID3")
plt.show()