sklearn 决策树与随机森林:参数设置、规则提取与模型可视化(初体验)

决策树

import os
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pydotplus


def tree_to_code(tree, feature_names, class_names=None, path='code.txt'):
    """Append the rules of a fitted decision tree to *path* as if/else text.

    The output is pseudo-Python:
    - a ``def tree(<feature_names>):`` header;
    - one ``if <feature> <= <threshold>:`` / ``else:`` pair per split node;
    - one ``return <value> -- <class name>`` line per leaf.

    Args:
        tree: A fitted estimator exposing a ``tree_`` attribute (e.g. a
            ``DecisionTreeClassifier`` or one estimator of a random forest).
        feature_names: Names indexed by the feature ids stored in
            ``tree.tree_.feature``; they must be in the same order as the
            columns of the training matrix.
        class_names: Names indexed by the argmax of each leaf's ``value``.
            Defaults to the module-level ``target_name`` list, which is what
            the script relied on before this parameter existed.
        path: File the rules are appended to (an existing file is extended,
            not overwritten).
    """
    if class_names is None:
        class_names = target_name  # backward-compatible fallback to the script global
    tree_ = tree.tree_
    # Map each node to its split feature name; leaves carry TREE_UNDEFINED.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print('feature_name:', feature_name)

    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        """Emit the subtree rooted at *node*, indented by *depth* levels."""
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Split node: left child holds samples with feature <= threshold.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: print the stored value and the name of its largest entry.
            value = tree_.value[node]
            lines.append("{}return {} -- {}".format(
                indent, value, class_names[np.argmax(value)]))

    recurse(0, 1)

    # Open the file once (instead of once per node) and append everything.
    with open(path, 'a', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')

pwd = os.getcwd()
titanic = pd.read_csv(os.path.join(pwd, 'ta.txt'))
# Fill missing ages with the column mean. The result is assigned back rather
# than using inplace=True on titanic['age']: under pandas copy-on-write that
# chained in-place call does not update `titanic` itself.
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())
# Features used for splitting
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']
labels = [0, 1]
target_name = ["deid", "survived"]

# DictVectorizer sorts its feature names by default, so the names handed to
# the tree exporters must be sorted the same way to line up with the columns.
fea_name = ["sex", "age", "pclass"]
fea_name.sort()

# test_size is the fraction of all rows held out as test data
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

dt = DictVectorizer(sparse=False)  # sparse=False: return a dense array

# Learn the vocabulary on the training split only and reuse it for the test
# split, so both matrices have the same columns in the same order.
x_train = dt.fit_transform(x_train.to_dict(orient="records"))

x_test = dt.transform(x_test.to_dict(orient="records"))

# Decision tree with default hyper-parameters. Options that may be worth trying:
#   class_weight='balanced'  -> reweight classes for an imbalanced dataset
#   criterion='entropy'      -> split on information gain instead of the default gini
#   max_features='sqrt'
dtc = DecisionTreeClassifier().fit(x_train, y_train)  # fit() returns the estimator itself

dt_predict = dtc.predict(x_test)

# Dump the learned rules as if/else text
tree_to_code(dtc, fea_name)

print(dtc.score(x_test, y_test))

print(classification_report(y_test, dt_predict, labels=labels, target_names=target_name))

# Confusion matrix of the decision tree's own predictions (dt_predict), then
# plot it. Rows are true labels, columns are predicted labels.
confmat = confusion_matrix(y_true=y_test, y_pred=dt_predict, labels=labels)
print(confmat)

fig, ax = plt.subplots(figsize=(3, 3))
ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
# Write each cell's count at the centre of its square.
for i, row in enumerate(confmat):
    for j, count in enumerate(row):
        ax.text(x=j, y=i, s=count, va='center', ha='center')

plt.xticks(range(len(confmat)), labels)
plt.yticks(range(len(confmat)), labels)
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.savefig('confusion_matrix.png')
plt.show()

# Render the decision tree to a PDF via Graphviz.
# When run from PyCharm, Graphviz may not be on PATH; replace the placeholder
# below with the real Graphviz bin directory.
graphviz_bin = 'graphviz的bin路径'
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + graphviz_bin
dot_data = tree.export_graphviz(
    dtc,
    out_file=None,
    feature_names=fea_name,
    class_names=target_name,
    filled=True,
    rounded=True,
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("descion_tree.pdf")

随机森林

# Data loading and preprocessing are the same as in the decision-tree section
# (x_train, x_test, y_train, y_test, labels, target_name, fea_name, pwd).
# Random forest:
# n_estimators is set explicitly: scikit-learn 0.20/0.21 emit a FutureWarning
# when it is omitted, and from 0.22 on its default is 100.
rfc = RandomForestClassifier(n_estimators=100, max_depth=6)

rfc.fit(x_train, y_train)

rfc_y_predict = rfc.predict(x_test)

print(rfc.score(x_test, y_test))

print(classification_report(y_test, rfc_y_predict, labels=labels, target_names=target_name))

# Switch into <pwd>/forest (created if missing) so the relative PDF names
# below are written there.
forest_dir = os.path.join(pwd, 'forest')
os.makedirs(forest_dir, exist_ok=True)
os.chdir(forest_dir)

for idx, estimator in enumerate(rfc.estimators_):
    # Export each tree of the forest as its own PDF
    filename = 'forest_{}.pdf'.format(idx)
    dot_data = tree.export_graphviz(estimator,
                                    out_file=None,
                                    feature_names=fea_name,
                                    class_names=target_name,
                                    rounded=True,
                                    proportion=False,
                                    precision=2,
                                    filled=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(filename)

本地文件 ta.txt 的原始数据
其中性别(sex)和舱位等级(pclass)已分别做了数值化处理

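提取的规则代码块(即 tree_to_code 写入 code.txt 的内容,是伪代码,不能直接运行。注意:下面叶子节点的类别名与 value 中较大的一列相反,例如 [[ 0. 12.]] 被标为 deid,说明这段输出是在 target_name 顺序与当前代码相反时生成的)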

def tree(age, pclass, sex):
  if sex <= 1.5:
    if age <= 10.0:
      if pclass <= 2.5:
        return [[ 0. 12.]] deid
      else:  # if pclass > 2.5
        if age <= 0.583299994468689:
          return [[1. 0.]] survived
        else:  # if age > 0.583299994468689
          if age <= 4.0:
            return [[0. 3.]] deid
          else:  # if age > 4.0
            if age <= 7.5:
              return [[2. 0.]] survived
            else:  # if age > 7.5
              return [[1. 2.]] deid
    else:  # if age > 10.0
      if pclass <= 1.5:
        if age <= 54.5:
          if age <= 29.0:
            if age <= 17.5:
              return [[0. 2.]] deid
            else:  # if age > 17.5
              if age <= 24.5:
                if age <= 20.0:
                  return [[2. 0.]] survived
                else:  # if age > 20.0
                  if age <= 23.5:
                    if age <= 21.5:
                      return [[0. 1.]] deid
                    else:  # if age > 21.5
                      if age <= 22.5:
                        return [[1. 0.]] survived
                      else:  # if age > 22.5
                        return [[0. 1.]] deid
                  else:  # if age > 23.5
                    return [[2. 0.]] survived
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 2.]] deid
                else:  # if age > 26.0
                  if age <= 27.5:
                    return [[0. 1.]] deid
                  else:  # if age > 27.5
                    return [[1. 2.]] deid
          else:  # if age > 29.0
            if age <= 33.5:
              if age <= 31.09709072113037:
                return [[4. 0.]] survived
              else:  # if age > 31.09709072113037
                if age <= 32.09709072113037:
                  return [[29. 10.]] survived
                else:  # if age > 32.09709072113037
                  return [[2. 0.]] survived
            else:  # if age > 33.5
              if age <= 36.5:
                if age <= 35.5:
                  return [[0. 2.]] deid
                else:  # if age > 35.5
                  return [[1. 4.]] deid
              else:  # if age > 36.5
                if age <= 47.5:
                  if age <= 38.5:
                    if age <= 37.5:
                      return [[1. 1.]] survived
                    else:  # if age > 37.5
                      return [[1. 1.]] survived
                  else:  # if age > 38.5
                    if age <= 45.5:
                      if age <= 41.5:
                        if age <= 39.5:
                          return [[3. 1.]] survived
                        else:  # if age > 39.5
                          return [[2. 0.]] survived
                      else:  # if age > 41.5
                        if age <= 43.0:
                          return [[2. 1.]] survived
                        else:  # if age > 43.0
                          if age <= 44.5:
                            return [[1. 0.]] survived
                          else:  # if age > 44.5
                            return [[3. 1.]] survived
                    else:  # if age > 45.5
                      if age <= 46.5:
                        return [[5. 0.]] survived
                      else:  # if age > 46.5
                        return [[3. 1.]] survived
                else:  # if age > 47.5
                  if age <= 48.5:
                    return [[1. 2.]] deid
                  else:  # if age > 48.5
                    if age <= 51.5:
                      if age <= 49.5:
                        return [[2. 1.]] survived
                      else:  # if age > 49.5
                        return [[3. 0.]] survived
                    else:  # if age > 51.5
                      if age <= 53.0:
                        return [[1. 1.]] survived
                      else:  # if age > 53.0
                        return [[1. 1.]] survived
        else:  # if age > 54.5
          return [[14.  0.]] survived
      else:  # if pclass > 1.5
        if age <= 29.5:
          if age <= 25.5:
            if age <= 23.5:
              if age <= 18.5:
                return [[17.  0.]] survived
              else:  # if age > 18.5
                if age <= 19.5:
                  if pclass <= 2.5:
                    return [[1. 0.]] survived
                  else:  # if pclass > 2.5
                    return [[4. 1.]] survived
                else:  # if age > 19.5
                  if age <= 20.5:
                    return [[8. 0.]] survived
                  else:  # if age > 20.5
                    if age <= 22.5:
                      if age <= 21.5:
                        if pclass <= 2.5:
                          return [[5. 0.]] survived
                        else:  # if pclass > 2.5
                          return [[4. 1.]] survived
                      else:  # if age > 21.5
                        if pclass <= 2.5:
                          return [[3. 1.]] survived
                        else:  # if pclass > 2.5
                          return [[3. 0.]] survived
                    else:  # if age > 22.5
                      return [[7. 0.]] survived
            else:  # if age > 23.5
              if age <= 24.5:
                if pclass <= 2.5:
                  return [[1. 1.]] survived
                else:  # if pclass > 2.5
                  return [[6. 1.]] survived
              else:  # if age > 24.5
                if pclass <= 2.5:
                  return [[4. 0.]] survived
                else:  # if pclass > 2.5
                  return [[4. 1.]] survived
          else:  # if age > 25.5
            return [[23.  0.]] survived
        else:  # if age > 29.5
          if age <= 45.5:
            if age <= 44.5:
              if age <= 32.5:
                if age <= 31.59709072113037:
                  if pclass <= 2.5:
                    if age <= 30.59709072113037:
                      return [[8. 0.]] survived
                    else:  # if age > 30.59709072113037
                      return [[32.  4.]] survived
                  else:  # if pclass > 2.5
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[220.  32.]] survived
                else:  # if age > 31.59709072113037
                  if pclass <= 2.5:
                    return [[3. 2.]] survived
                  else:  # if pclass > 2.5
                    return [[5. 0.]] survived
              else:  # if age > 32.5
                if age <= 35.5:
                  return [[11.  0.]] survived
                else:  # if age > 35.5
                  if age <= 36.5:
                    if pclass <= 2.5:
                      return [[1. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[0. 1.]] deid
                  else:  # if age > 36.5
                    if pclass <= 2.5:
                      if age <= 40.5:
                        return [[3. 0.]] survived
                      else:  # if age > 40.5
                        if age <= 41.5:
                          return [[1. 1.]] survived
                        else:  # if age > 41.5
                          return [[3. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[11.  0.]] survived
            else:  # if age > 44.5
              if pclass <= 2.5:
                return [[2. 0.]] survived
              else:  # if pclass > 2.5
                return [[1. 1.]] survived
          else:  # if age > 45.5
            return [[13.  0.]] survived
  else:  # if sex > 1.5
    if pclass <= 2.5:
      if pclass <= 1.5:
        if age <= 62.5:
          if age <= 36.5:
            if age <= 35.5:
              if age <= 24.5:
                return [[ 0. 19.]] deid
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 0.]] survived
                else:  # if age > 26.0
                  if age <= 31.09709072113037:
                    return [[0. 6.]] deid
                  else:  # if age > 31.09709072113037
                    if age <= 32.09709072113037:
                      return [[ 1. 23.]] deid
                    else:  # if age > 32.09709072113037
                      return [[0. 5.]] deid
            else:  # if age > 35.5
              return [[1. 3.]] deid
          else:  # if age > 36.5
            return [[ 0. 31.]] deid
        else:  # if age > 62.5
          if age <= 63.5:
            return [[1. 1.]] survived
          else:  # if age > 63.5
            return [[0. 1.]] deid
      else:  # if pclass > 1.5
        if age <= 17.5:
          return [[0. 9.]] deid
        else:  # if age > 17.5
          if age <= 22.5:
            if age <= 21.5:
              if age <= 18.5:
                return [[1. 3.]] deid
              else:  # if age > 18.5
                return [[0. 5.]] deid
            else:  # if age > 21.5
              return [[2. 0.]] survived
          else:  # if age > 22.5
            if age <= 26.5:
              return [[0. 5.]] deid
            else:  # if age > 26.5
              if age <= 27.5:
                return [[1. 1.]] survived
              else:  # if age > 27.5
                if age <= 29.5:
                  return [[0. 5.]] deid
                else:  # if age > 29.5
                  if age <= 30.5:
                    return [[1. 2.]] deid
                  else:  # if age > 30.5
                    if age <= 46.0:
                      if age <= 43.0:
                        if age <= 39.0:
                          if age <= 37.0:
                            if age <= 31.59709072113037:
                              if age <= 31.09709072113037:
                                return [[0. 2.]] deid
                              else:  # if age > 31.09709072113037
                                return [[ 3. 17.]] deid
                            else:  # if age > 31.59709072113037
                              return [[0. 9.]] deid
                          else:  # if age > 37.0
                            return [[1. 0.]] survived
                        else:  # if age > 39.0
                          return [[0. 4.]] deid
                      else:  # if age > 43.0
                        return [[1. 0.]] survived
                    else:  # if age > 46.0
                      return [[0. 5.]] deid
    else:  # if pclass > 2.5
      if age <= 19.5:
        if age <= 12.0:
          if age <= 5.5:
            if age <= 1.0833500027656555:
              return [[0. 1.]] deid
            else:  # if age > 1.0833500027656555
              if age <= 3.5:
                return [[1. 0.]] survived
              else:  # if age > 3.5
                return [[0. 1.]] deid
          else:  # if age > 5.5
            return [[2. 0.]] survived
        else:  # if age > 12.0
          if age <= 17.5:
            if age <= 15.5:
              return [[0. 1.]] deid
            else:  # if age > 15.5
              if age <= 16.5:
                return [[1. 3.]] deid
              else:  # if age > 16.5
                return [[0. 1.]] deid
          else:  # if age > 17.5
            if age <= 18.5:
              return [[2. 3.]] deid
            else:  # if age > 18.5
              return [[0. 1.]] deid
      else:  # if age > 19.5
        if age <= 21.5:
          return [[3. 0.]] survived
        else:  # if age > 21.5
          if age <= 23.5:
            if age <= 22.5:
              return [[1. 2.]] deid
            else:  # if age > 22.5
              return [[0. 1.]] deid
          else:  # if age > 23.5
            if age <= 32.5:
              if age <= 31.59709072113037:
                if age <= 25.5:
                  return [[1. 1.]] survived
                else:  # if age > 25.5
                  if age <= 29.0:
                    return [[2. 0.]] survived
                  else:  # if age > 29.0
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[75. 40.]] survived
              else:  # if age > 31.59709072113037
                return [[1. 0.]] survived
            else:  # if age > 32.5
              if age <= 37.0:
                return [[0. 3.]] deid
              else:  # if age > 37.0
                if age <= 42.5:
                  return [[2. 0.]] survived
                else:  # if age > 42.5
                  return [[1. 1.]] survived

混淆矩阵

结果图
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值