Task02:基于决策树的分类预测

在决策树的算法中,建立决策树的关键,即在当前状态下选择哪个属性作为分类依据。根据不同的目标函数,建立决策树主要有一下三种算法:

  1. ID3
  2. C4.5
  3. CART

主要的区别就是选择的目标函数不同,ID3使用的是信息增益,C4.5使用信息增益率,CART使用的是Gini系数。

  • 信息熵是一种衡量数据混乱程度的指标,信息熵越小,则数据的“纯度”越高.

  • 熵H(Y)与条件熵H(Y|X)之差称为互信息。决策树学习中的信息增益等价于训练数据集中类与特征的互信息。

DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion=‘entropy’,
max_depth=None, max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=‘deprecated’,
random_state=None, splitter=‘best’)

决策树模型,其函数的参数含义如下所示:

  1. criterion:gini或者entropy,前者是基尼系数,后者是信息熵。
  2. splitter: best or random 前者是在所有特征中找最好的切分点 后者是在部分特征中,默认的”best”适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐”random” 。
  3. max_features:None(所有),log2,sqrt,N 特征小于50的时候一般使用所有的
  4. max_depth: int or None, optional (default=None) 设置决策随机森林中的决策树的最大深度,深度越大,越容易过拟合,推荐树的深度为:5-20之间。
  5. min_samples_split:设置结点的最小样本数量,当样本数量可能小于此值时,结点将不会在划分。
  6. min_samples_leaf: 这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。
  7. min_weight_fraction_leaf: 这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝默认是0,就是不考虑权重问题。
  8. max_leaf_nodes: 通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。
  9. class_weight: 指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多导致训练的决策树过于偏向这些类别。这里可以自己指定各个样本的权重,如果使用“balanced”,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
  10. min_impurity_split: 这个值限制了决策树的增长,如果某节点的不纯度(基尼系数,信息增益,均方差,绝对差)小于这个阈值则该节点不再生成子节点。即为叶子节点 。
  • 决策树的主要优点:

    1. 具有很好的解释性,模型可以生成可以理解的规则。
    2. 可以发现特征的重要程度。
    3. 模型的计算复杂度较低。
  • 决策树的主要缺点:

    1. 模型容易过拟合,需要采用减枝技术处理。
  1. 不能很好利用连续型特征。
  2. 预测能力有限,无法达到其他强监督模型效果。
  3. 方差较高,数据分布的轻微改变很容易造成树结构完全不同。

1. Demo 实践

1.1 库函数导入

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

1.2 训练模型

x=np.array([[1,2],[3,4],[-1,-2],[-4,-5],[1,-2]])
y=np.array([1,1,0,0,0])
tree_model=DecisionTreeClassifier()
tree_model=tree_model.fit(x,y)

1.3 数据和模型可视化

 plt.figure(figsize=(10,6))
plt.scatter(x[:,0],x[:,1],c=y)
#plt.plot(x[:,0],x[:,1])
plt.xticks(range(-5,5,2))
plt.yticks(range(-6,7))
plt.title("scatter")
plt.xlabel('x0')
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-90FQc9XE-1598089259051)(output_10_0.png)]

## 可视化决策树

import graphviz

#sklearn库集成了graphviz库中的export_graphviz方法,作为sklearn中tree对象的属性。
#因此,在tree对象即决策树模型已经训练完毕的前提下,可以通过tree.export_graphviz()输出能被graphviz库处理的.dot文件
#dot_data=tree.export_graphviz(model_tree, out_file=None, max_depth=5, feature_names=names_list, filled=True,rounded=True)
#dot_data = tree.export_graphviz(ID3, out_file=None,feature_names=data.columns[:-1],class_names=np.unique(y))
dot_data = tree.export_graphviz(tree_model, out_file=None)
#使用pydotplus库调用graph_from_dot_data()方法将生成的.dot文件转置为.graph图形对象。
#graph图形对象不能直接可视化,可以通过write_pdf/write_jpg等方法转置为可打开的文件查看。
graph = graphviz.Source(dot_data) 
#使用render方法将保存dot源码, 并且会渲染图形, 使用view=True参数可以自动打开应用程序以便浏览生成的图:
graph.render("pengunis", view=True)
'pengunis.pdf'

在这里插入图片描述

conda install python-graphviz

数据信息简单查看

data.head()
studyNameSample NumberSpeciesRegionIslandStageIndividual IDClutch CompletionDate EggCulmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)SexDelta 15 N (o/oo)Delta 13 C (o/oo)Comments
0PAL07081Adelie Penguin (Pygoscelis adeliae)AnversTorgersenAdult, 1 Egg StageN1A1Yes2007-11-1139.118.7181.03750.0MALENaNNaNNot enough blood for isotopes.
1PAL07082Adelie Penguin (Pygoscelis adeliae)AnversTorgersenAdult, 1 Egg StageN1A2Yes2007-11-1139.517.4186.03800.0FEMALE8.94956-24.69454NaN
2PAL07083Adelie Penguin (Pygoscelis adeliae)AnversTorgersenAdult, 1 Egg StageN2A1Yes2007-11-1640.318.0195.03250.0FEMALE8.36821-25.33302NaN
3PAL07084Adelie Penguin (Pygoscelis adeliae)AnversTorgersenAdult, 1 Egg StageN2A2Yes2007-11-16NaNNaNNaNNaNNaNNaNNaNAdult not sampled.
4PAL07085Adelie Penguin (Pygoscelis adeliae)AnversTorgersenAdult, 1 Egg StageN3A1Yes2007-11-1636.719.3193.03450.0FEMALE8.76651-25.32426NaN
type(data)
pandas.core.frame.DataFrame
data=data[['Species','Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']]
type(data)
pandas.core.frame.DataFrame
data.head()
SpeciesCulmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
0Adelie Penguin (Pygoscelis adeliae)39.118.7181.03750.0
1Adelie Penguin (Pygoscelis adeliae)39.517.4186.03800.0
2Adelie Penguin (Pygoscelis adeliae)40.318.0195.03250.0
3Adelie Penguin (Pygoscelis adeliae)NaNNaNNaNNaN
4Adelie Penguin (Pygoscelis adeliae)36.719.3193.03450.0
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 5 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Species              344 non-null    object 
 1   Culmen Length (mm)   342 non-null    float64
 2   Culmen Depth (mm)    342 non-null    float64
 3   Flipper Length (mm)  342 non-null    float64
 4   Body Mass (g)        342 non-null    float64
dtypes: float64(4), object(1)
memory usage: 13.6+ KB
data=data.fillna(-1)
data.head()
SpeciesCulmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
0Adelie Penguin (Pygoscelis adeliae)39.118.7181.03750.0
1Adelie Penguin (Pygoscelis adeliae)39.517.4186.03800.0
2Adelie Penguin (Pygoscelis adeliae)40.318.0195.03250.0
3Adelie Penguin (Pygoscelis adeliae)-1.0-1.0-1.0-1.0
4Adelie Penguin (Pygoscelis adeliae)36.719.3193.03450.0
data['Species'].unique()
array(['Adelie Penguin (Pygoscelis adeliae)',
       'Gentoo penguin (Pygoscelis papua)',
       'Chinstrap penguin (Pygoscelis antarctica)'], dtype=object)
data['Species'].value_counts()
Adelie Penguin (Pygoscelis adeliae)          152
Gentoo penguin (Pygoscelis papua)            124
Chinstrap penguin (Pygoscelis antarctica)     68
Name: Species, dtype: int64
pd.Series(data['Species']).value_counts()
Adelie Penguin (Pygoscelis adeliae)          152
Gentoo penguin (Pygoscelis papua)            124
Chinstrap penguin (Pygoscelis antarctica)     68
Name: Species, dtype: int64
data.describe()
Culmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
count344.000000344.000000344.000000344.000000
mean43.66075617.045640199.7412794177.319767
std6.4289572.40561420.806759861.263227
min-1.000000-1.000000-1.000000-1.000000
25%39.20000015.500000190.0000003550.000000
50%44.25000017.300000197.0000004025.000000
75%48.50000018.700000213.0000004750.000000
max59.60000021.500000231.0000006300.000000

2.4 可视化描述

plt.figure()
sns.pairplot(data=data, diag_kind='hist', hue= 'Species')
plt.show()
<Figure size 432x288 with 0 Axes>

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kHfapcTO-1598089259058)(output_35_1.png)]

'''为了方便我们将标签转化为数字 
   'Adelie Penguin (Pygoscelis adeliae)'  ------0 
   'Gentoo penguin (Pygoscelis papua)'  ------1 
   'Chinstrap penguin (Pygoscelis antarctica)   ------2 '''
def trans(x): 
    if x == data['Species'].unique()[0]:   
        return 0 
    if x == data['Species'].unique()[1]:   
        return 1 
    if x == data['Species'].unique()[2]: 
        return 2
    
data['Species'] = data['Species'].apply(trans)
for col in data.columns: 
    if col != 'Species': 
        sns.boxplot(x='Species', y=col, saturation=0.5, palette='pastel', data=data) 
        plt.title(col) 
        plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述在这里插入图片描述

# 选取其前三个特征绘制三维散点图 
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(10,8)) 
ax = fig.add_subplot(111, projection='3d')
data_class0 = data[data['Species']==0].values 
data_class1 = data[data['Species']==1].values 
data_class2 = data[data['Species']==2].values
ax.scatter(data_class0[:,0], data_class0[:,1], data_class0[:,2],label=data['Species'].unique()[0]) 
ax.scatter(data_class1[:,0], data_class1[:,1], data_class1[:,2],label=data['Species'].unique()[1]) 
ax.scatter(data_class2[:,0], data_class2[:,1], data_class2[:,2],label=data['Species'].unique()[2]) 
plt.legend()
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U09SceNb-1598089259070)(output_38_0.png)]

2.5 利用决策树模型在二分类上进行训练和预测

data.columns
Index(['Species', 'Culmen Length (mm)', 'Culmen Depth (mm)',
       'Flipper Length (mm)', 'Body Mass (g)'],
      dtype='object')
data.Species
0      0
1      0
2      0
3      0
4      0
      ..
339    2
340    2
341    2
342    2
343    2
Name: Species, Length: 344, dtype: int64
data['Species']
0      0
1      0
2      0
3      0
4      0
      ..
339    2
340    2
341    2
342    2
343    2
Name: Species, Length: 344, dtype: int64
from sklearn.model_selection import train_test_split
data=data.fillna(-1)
data_target_part=data[data['Species'].isin([0,1])][['Species']]
data_feature_part=data[data['Species'].isin([0,1])][['Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']]
data_target_part.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 276 entries, 0 to 275
Data columns (total 1 columns):
 #   Column   Non-Null Count  Dtype
---  ------   --------------  -----
 0   Species  276 non-null    int64
dtypes: int64(1)
memory usage: 4.3 KB
data_target_part.head()
Species
00
10
20
30
40
import numpy as np
import pandas as pd
data_target_part['Species'].value_counts()
0    152
1    124
Name: Species, dtype: int64
data_feature_part.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 276 entries, 0 to 275
Data columns (total 4 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Culmen Length (mm)   276 non-null    float64
 1   Culmen Depth (mm)    276 non-null    float64
 2   Flipper Length (mm)  276 non-null    float64
 3   Body Mass (g)        276 non-null    float64
dtypes: float64(4)
memory usage: 10.8 KB
x_train,x_test,y_train,y_test=train_test_split(data_feature_part,data_target_part,test_size=0.2, random_state=2020)
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

tree_clf=DecisionTreeClassifier(criterion='entropy')
tree_clf.fit(x_train,y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='entropy',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
# 可视化
import graphviz
dot_data=tree.export_graphviz(tree_clf,out_file=None,feature_names=list(x_train.columns),class_names=str(np.unique(y_train)[0])+str(np.unique(y_train)[1]))
graph=graphviz.Source(dot_data)
graph.render("pingo1",view=True)
'pingo1.pdf'

在这里插入图片描述

train_predict=tree_clf.predict(x_train)
test_predict=tree_clf.predict(x_test)
print("test_predict:\n",test_predict)
test_predict:
 [0 0 1 0 1 0 1 1 0 1 0 1 1 0 1 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0
 0 1 1 0 0 1 1 0 0 0 1 0 0 0 1 0 1 0 0]
from sklearn import metrics
print("x_train accuracy is :",metrics.accuracy_score(y_train,train_predict))
print("x_test accuracy is :",metrics.accuracy_score(y_test,test_predict))
x_train accuracy is : 0.9954545454545455
x_test accuracy is : 1.0
confusion_metrix_result=metrics.confusion_matrix(test_predict,y_test)
print('confusion_metrix_result is :\n',confusion_metrix_result)
confusion_metrix_result is :
 [[31  0]
 [ 0 25]]
plt.figure()
sns.heatmap(confusion_metrix_result, annot=True, cmap='Blues')
plt.xlabel('test_predict')
plt.ylabel("y_test")
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dNcs7vJx-1598089259074)(output_54_0.png)]

2.6 利用决策树模型在三分类(多分类)上进行训练和预测

## 测试集大小为20%, 80%/20%分 
x_train, x_test, y_train, y_test = train_test_split(data[['Culmen Length (mm)','Culmen Depth (mm)',
                                                          'Flipper Length (mm)','Body Mass (g)']], data[['Species']], test_size = 0.2, random_state = 2020)  
clf = DecisionTreeClassifier() 
clf.fit(x_train, y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
## 在训练集和测试集上分布利用训练好的模型进行预测 
train_predict = clf.predict(x_train) 
test_predict = clf.predict(x_test)
#利用 predict_proba 函数预测其概率 
train_predict_proba = clf.predict_proba(x_train) 
test_predict_proba = clf.predict_proba(x_test)
print('The test predict Probability of each class:\n',test_predict_proba) ## 其中第一列代表预测为0类的概率,第二列代表预测为1类的概率,第三列代表预测为2类的概率。
## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果 
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict)) 
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))
The test predict Probability of each class:
 [[0. 0. 1.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]]
The accuracy of the Logistic Regression is: 0.9963636363636363
The accuracy of the Logistic Regression is: 0.9710144927536232
print(test_predict)
[2 1 1 0 0 2 2 0 1 0 1 1 0 1 1 1 0 1 0 0 2 0 2 0 0 0 1 0 1 0 0 2 2 1 0 1 1
 0 0 1 0 0 1 0 0 2 2 0 0 1 0 0 1 1 2 2 1 0 0 0 1 1 2 2 0 1 2 0 0]
## 查看混淆矩阵 
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test) 
print('The confusion matrix result:\n',confusion_matrix_result)
# 利用热力图对于结果进行可视化 
plt.figure(figsize=(8, 6)) 
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues') 
plt.xlabel('Predicted labels') 
plt.ylabel('True labels') 
plt.show()
The confusion matrix result:
 [[31  1  0]
 [ 0 23  0]
 [ 1  0 13]]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iYWOkfnY-1598089259074)(output_59_1.png)]

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值