作 者:echoy189
介 绍:spark数据处理与算法交流
公众号:spark推荐系统
决策树(decisiontree)学习的算法通常是一个递归地选择最优特征, 并根据该特征对训练数据进行分割, 使得各个子数据集有一个最好的分类的过程
目录
初步了解
决策树的类型
决策树过拟合
代码展示
一) 初步了解
目标:通过大量的数据生成一棵非常好的树,用这棵树来预测新来的数据
决策树的生成是数据不断分裂的递归过程。每一次分裂,尽可能让类别一样的数据在树的一边,当树的叶子节点的数据都是一类的时候,则停止分裂。(if else 语句即决策树模型本身)
当构建好一个判断模型后,新来一个用户后,可以根据构建好的模型直接进行判断,比如新用户特性为:无房产、单身、年收入55K。那么根据判断得出该用户无法进行债务偿还。这种决策对于借贷业务就有比较好指导意义
决策树的特点
1.可以处理非线性问题
2.可解释性强
3.模型简单,模型预测效率高
4.不太容易显示地使用函数表达,不可微
那么如何生成决策树呢?
1.将原始数据集筛选,分类成子数据集
a 每次分成几份?
b 以什么条件来分份?
2.对生成的子数据集不断分裂,什么时候停止?
3.利用最终生成的n份数据的共性来代表这个节点,什么是共性?
二) 决策树的类型
gini系数 (CART树,分类问题)
信息增益 (ID3,分类问题)
信息增益率 (C4.5,分类问题)
MSE (CART树,回归问题)
不同的分裂计算标准对用不同的类别的决策树
gini系数(用于CART分类树)
公式:
Gini系数越小,代表D集合中的数据越纯
多个节点的Gini系数:Gini(D)=|D1|/|D|*Gini(D1)+|D2|/|D|*GINI(D2)
分裂:优先选择( 分裂前的Gini系数 - 分类后的多个节点的Gini系数) 最大的分类条件
使用iris数据集来理解决策树(cart树)中的Gini系数
1.数据集中三种花(setosa,versicolor,virginica)每种各50个样本,一共150样本
2.第一次分裂条件为petal length(cm) <= 2.45
3.按照True/False分成左右两个子节点
4.左边(True)子节点的样本只有setosa 类50个 ,所以该节点的Gini系数为 Gini(D)=1-(1*1 +0+0 ) =0
5.右边(False)子节点的样本有versicolor和virginica各50个,所有该节点的Gini系数为
Gini(D) = 1- (0.5*0.5+0.5*0.5) = 0.5
6.然后右边节点在根据petal width(cm) <= 1.75 进行分类,又得到左右两个节点
7.同理左节点gini系数=0.168 ,右节点gini系数 =0.0425
注:Cart树 均为二叉树
信息增益(用于ID3分类树)
信息:I(X = xi)= -log₂p(xi)
信息熵:
一个集合中信息熵越低代表这个集合中的纯度越高(和gini系数一致)
信息增益:(分裂前的信息熵 - 分裂后的信息熵)
一次分裂后的信息增益越大,代表这次分裂提升的纯度就越高
举个例子来了解下信息增益
整体熵:
E(S) = -5/15• log₂(5/15) - 10/15•log₂(10/15) = 0.9182
性别熵
E(g男) = -3/8 • log₂(3/8) - 5/8 •log₂(5/8) = 0.9543
E(g女) = - 2/7• log₂(2/7) - 5/7 •log₂(5/7) = 0.8631
所以性别信息增益为:
IGain(S,g) =E(S) - 8/15E(g男) -7/15E(g女) =0.0064
活跃度熵
E(a高) = 0 E(a中) = 0.7219 E(a低) =0
所以活跃度信息增益为:
IGain(S,a) =E(S) - 6/15E(a高) - 5/15E(a中) -4/15E(a低) =0.6776
使用ID3会有一个问题,假如分裂条件选择的是uin,那么会分成15个节点。此时每个叶子节点只有一个样本,信息熵都为0,此时的信息增益最大。所以ID3会倾向于特征值多的特征去分裂,如果使用uin去分裂,其实并不是我们想看到的结果,相当于没分裂。在ID3算法的基础上,进行算法优化提出的一种算法C4.5
信息增益率(用于C4.5分类树)
对于多叉树,如果不限制分裂多少支,一次分裂就可以将信息熵降为0,如何平衡分裂情况与信息增益?
信息增益率:信息增益 /类别本身的熵
分别计算三种分类方法的信息增益率(还是上面的图)
GR(1)=Gain/-(6/15*log6/15+5/15*log5/15+4/15*log4/15)
GR(2)=Gain/-(11/15*log11/15+ 4/15*log4/15)
GR(3)=Gain/-(6/15*log6/15 + 9/15*log9/15)
MSE(用于CART回归树)
当前节点的预测值:取的是当前节点所有样本的y值加和 除以 样本个数 (即当前样本的label的均值)
每个节点的MSE:1/n•Σ|y -y_hat|
MSE增益:分裂前的MSE - 分类后的MSE
ID3/C4.5/CART算法总结
使用最多的还是CART树,既可以做分类又可以做回归
三) 决策树过拟合问题
如果训练模型的时候不限制树的生长,最终都会将结果分到最好(fully growntree) ,从而导致过拟合。可以通过树的剪枝,来防止过拟合
预剪枝
提前制定好规则人为的让树不能完全生长
max_depth:设置树的最大深度
min_samples_split:一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
min_samples_leaf:一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
max_features:max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃
注:一般使用预剪枝防止决策树过拟合
后剪枝
首先利用训练集生成一棵fullygrown tree
把验证集中的数据放到tree中进行分类,记录结果得分
自叶子节点至根节点依次尝试cancle掉一次分裂(剪枝)记录验证集得分
选取最好得分最好的树的形态作为最终的树
四) 代码展示
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn import tree
data = pd.read_csv("D:\\study\\python_project\\my_study\\data\\mushrooms.csv")
# print(data)
data.head()
from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
for col in data.columns:
data[col] = labelencoder.fit_transform(data[col])
print(data.shape)
y = data['class']
X = data.drop('class',axis=1)
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=0,train_size=0.8)
columns = X_train.columns
print(columns)
# 数据标准化
from sklearn.preprocessing import StandardScaler
ss_X = StandardScaler()
ss_y = StandardScaler()
X_train = ss_X.fit_transform(X_train)
X_test = ss_X.transform(X_test)
from sklearn.tree import DecisionTreeClassifier
model_tree = DecisionTreeClassifier()
model_tree.fit(X_train,y_train)
y_prob = model_tree.predict_proba(X_test)[:,1]
y_pred = np.where(y_prob > 0.5,1,0)
model_tree.score(X_test,y_pred)
# 可视化树图
data_ = pd.read_csv("D:\\study\\python_project\\my_study\\data\\mushrooms.csv")
data_feature_name = data_.columns[1:]
data_target_name = np.unique(data_["class"])
import pydotplus
from sklearn import tree
from IPython.display import Image
import os
os.environ["PATH"] += os.pathsep + 'D:\solt\graphviz\bin'
dot_tree = tree.export_graphviz(model_tree,out_file=None,feature_names=data_feature_name,class_names=data_target_name,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
img = Image(graph.create_png())
graph.write_png("out.png")
往期精选
机器学习|主成分分析法PCA机器学习-线性回归(一)
机器学习-线性回归(二)
机器学习|梯度下降法
机器学习|逻辑回归
长按识别二维码关注我