决策树实践

数据集:

8.iris.data

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

10.1.Iris_DecisionTree.py

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline


def iris_type(s):
    it = {b'Iris-setosa': 0,
          b'Iris-versicolor': 1,
          b'Iris-virginica': 2}
    return it[s]


# 花萼长度、花萼宽度,花瓣长度,花瓣宽度
# iris_feature = 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'

if __name__ == "__main__":
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    mpl.rcParams['axes.unicode_minus'] = False

    path = '8.iris.data'  # 数据文件路径
    data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
    x, y = np.split(data, (4,), axis=1)
    # 为了可视化,仅使用前两列特征
    x = x[:, :2]
    print(x)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)
    #ss = StandardScaler()
    #ss = ss.fit(x_train)

    # 决策树参数估计
    # min_samples_split=10:结点包含的样本数不少于 10 时,才可能对其继续划分
    # min_samples_leaf=10:只有当划分产生的每个子结点样本数都不少于 10 时,该划分才会被采纳
    model = Pipeline([
        ('ss', StandardScaler()),
        ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3))])
    # clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
    model = model.fit(x_train, y_train)
    y_test_hat = model.predict(x_test)      # 测试数据

    # 保存
    # dot -Tpng -o 1.png 1.dot
    with open('iris_tree.dot', 'w') as f:
        tree.export_graphviz(model.named_steps['DTC'], out_file=f)

    # 画图
    N, M = 100, 100  # 横纵各采样多少个值
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # 第0列的范围
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # 第1列的范围
    t1 = np.linspace(x1_min, x1_max, N)
    t2 = np.linspace(x2_min, x2_max, M)
    x1, x2 = np.meshgrid(t1, t2)  # 生成网格采样点
    x_show = np.stack((x1.flat, x2.flat), axis=1)  # 测试点

    # # 无意义,只是为了凑另外两个维度
    # # 打开该注释前,确保注释掉x = x[:, :2]
    # x3 = np.ones(x1.size) * np.average(x[:, 2])
    # x4 = np.ones(x1.size) * np.average(x[:, 3])
    # x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1)  # 测试点

    cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
    y_show_hat = model.predict(x_show)  # 预测值
    y_show_hat = y_show_hat.reshape(x1.shape)  # 使之与输入的形状相同
    plt.figure(facecolor='w')
    plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light)  # 预测值的显示
    plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100, cmap=cm_dark, marker='o')  # 测试数据
    plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)  # 全部数据
    plt.xlabel(iris_feature[0], fontsize=15)
    plt.ylabel(iris_feature[1], fontsize=15)
    plt.xlim(x1_min, x1_max)
    plt.ylim(x2_min, x2_max)
    plt.grid(True)
    plt.title(u'鸢尾花数据的决策树分类', fontsize=17)
    plt.show()

    # 测试集上的预测结果
    y_test = y_test.reshape(-1)
    print(y_test_hat)
    print(y_test)
    result = (y_test_hat == y_test)   # True则预测正确,False则预测错误
    acc = np.mean(result)
    print('准确度: %.2f%%' % (100 * acc))

    # 过拟合:错误率
    depth = np.arange(1, 15)
    err_list = []
    for d in depth:
        clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)
        clf = clf.fit(x_train, y_train)
        y_test_hat = clf.predict(x_test)  # 测试数据
        result = (y_test_hat == y_test)  # True则预测正确,False则预测错误
        err = 1 - np.mean(result)
        err_list.append(err)
        print(d, ' 错误率: %.2f%%' % (100 * err))
    plt.figure(facecolor='w')
    plt.plot(depth, err_list, 'ro-', lw=2)
    plt.xlabel(u'决策树深度', fontsize=15)
    plt.ylabel(u'错误率', fontsize=15)
    plt.title(u'决策树深度与过拟合', fontsize=17)
    plt.grid(True)
    plt.show()
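代码中 criterion='entropy' 表示用信息熵作为划分标准:熵 H = -Σ p·log2(p) 衡量结点的不纯度。下面用一小段示意代码验证熵的两个极端情形(三类均衡时熵最大,纯结点时熵为 0):

```python
import numpy as np

def entropy(labels):
    """以 2 为底的信息熵:H = -Σ p·log2(p)"""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(entropy([0] * 50 + [1] * 50 + [2] * 50))   # 三类均衡时熵最大:log2(3) ≈ 1.585
print(entropy([0] * 50))                         # 纯结点的熵为 0
```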

结果:

E:\pythonwork\venv\Scripts\python.exe E:/pythonspace/10.RandomForest/10.1.Iris_DecisionTree.py
[[5.1 3.5]
 [4.9 3. ]
 [4.7 3.2]
 ...
 [6.2 3.4]
 [5.9 3. ]]
(x 共 150 行,即数据集的前两列特征,中间部分从略)
[0. 1. 2. 0. 2. 2. 2. 0. 0. 2. 1. 0. 2. 2. 1. 0. 1. 1. 0. 0. 1. 0. 2. 0.
 2. 1. 0. 0. 1. 2. 1. 2. 1. 2. 1. 0. 1. 0. 2. 2. 2. 0. 1. 2. 2.]
[0. 1. 1. 0. 2. 1. 2. 0. 0. 2. 1. 0. 2. 1. 1. 0. 1. 1. 0. 0. 1. 1. 1. 0.
 2. 1. 0. 0. 1. 2. 1. 2. 1. 2. 2. 0. 1. 0. 1. 2. 2. 0. 2. 2. 1.]
准确度: 80.00%
1  错误率: 44.44%
2  错误率: 40.00%
3  错误率: 20.00%
4  错误率: 24.44%
5  错误率: 24.44%
6  错误率: 26.67%
7  错误率: 35.56%
8  错误率: 40.00%
9  错误率: 35.56%
10  错误率: 40.00%
11  错误率: 37.78%
12  错误率: 37.78%
13  错误率: 40.00%
14  错误率: 37.78%

Process finished with exit code 0
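上面的程序用 np.mean(y_test_hat == y_test) 计算测试集准确率,它与 sklearn.metrics.accuracy_score 等价。下面是一个小示意(y_true、y_pred 为假想数据):

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 1, 2, 0, 2])
y_pred = np.array([0, 1, 1, 0, 2])   # 假想的预测结果,5 个中错 1 个

print(np.mean(y_pred == y_true))       # 0.8
print(accuracy_score(y_true, y_pred))  # 0.8
```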


10.2.Iris_DecisionTree_Enum.py

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.tree import DecisionTreeClassifier


def iris_type(s):
    it = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2}
    return it[s]

# 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'

if __name__ == "__main__":
    mpl.rcParams['font.sans-serif'] = [u'SimHei']  # 黑体 FangSong/KaiTi
    mpl.rcParams['axes.unicode_minus'] = False

    path = '8.iris.data'  # 数据文件路径
    data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
    x_prime, y = np.split(data, (4,), axis=1)

    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
    for i, pair in enumerate(feature_pairs):
        # 准备数据
        x = x_prime[:, pair]

        # 决策树学习
        clf = DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)
        dt_clf = clf.fit(x, y)

        # 画图
        N, M = 500, 500  # 横纵各采样多少个值
        x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # 第0列的范围
        x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # 第1列的范围
        t1 = np.linspace(x1_min, x1_max, N)
        t2 = np.linspace(x2_min, x2_max, M)
        x1, x2 = np.meshgrid(t1, t2)  # 生成网格采样点
        x_test = np.stack((x1.flat, x2.flat), axis=1)  # 测试点

        # 训练集上的预测结果
        y_hat = dt_clf.predict(x)
        y = y.reshape(-1)
        c = np.count_nonzero(y_hat == y)    # 统计预测正确的个数
        print('特征:  ', iris_feature[pair[0]], ' + ', iris_feature[pair[1]],)
        print('\t预测正确数目:', c,)
        print('\t准确率: %.2f%%' % (100 * float(c) / float(len(y))))

        # 显示
        cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
        y_hat = dt_clf.predict(x_test)  # 预测值
        y_hat = y_hat.reshape(x1.shape)  # 使之与输入的形状相同
        plt.subplot(2, 3, i+1)
        plt.pcolormesh(x1, x2, y_hat, cmap=cm_light)  # 预测值
        plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', cmap=cm_dark)  # 样本
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.xlim(x1_min, x1_max)
        plt.ylim(x2_min, x2_max)
        plt.grid()
    plt.suptitle(u'决策树对鸢尾花数据的两特征组合的分类结果', fontsize=18)
    plt.tight_layout(pad=2)
    plt.subplots_adjust(top=0.92)
    plt.show()
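代码里手写的 feature_pairs 就是 4 个特征的全部两两组合(共 C(4,2)=6 组),也可以用 itertools.combinations 自动生成,避免手工枚举(示意):

```python
from itertools import combinations

# 与正文中手写的列表一致
feature_pairs = [list(p) for p in combinations(range(4), 2)]
print(feature_pairs)   # [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
```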

结果:

E:\pythonwork\venv\Scripts\python.exe E:/pythonspace/10.RandomForest/10.2.Iris_DecisionTree_Enum.py
特征:   花萼长度  +  花萼宽度
	预测正确数目: 123
	准确率: 82.00%
特征:   花萼长度  +  花瓣长度
	预测正确数目: 145
	准确率: 96.67%
特征:   花萼长度  +  花瓣宽度
	预测正确数目: 144
	准确率: 96.00%
特征:   花萼宽度  +  花瓣长度
	预测正确数目: 143
	准确率: 95.33%
特征:   花萼宽度  +  花瓣宽度
	预测正确数目: 145
	准确率: 96.67%
特征:   花瓣长度  +  花瓣宽度
	预测正确数目: 147
	准确率: 98.00%

Process finished with exit code 0


10.3.DecisionTreeRegressor.py

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor


if __name__ == "__main__":
    N = 100
    x = np.random.rand(N) * 6 - 3     # [-3,3)
    x.sort()
    y = np.sin(x) + np.random.randn(N) * 0.05
    print(y)
    x = x.reshape(-1, 1)  # 变形为 N 行 1 列:N 个样本,每个样本 1 维
    print(x)

    reg = DecisionTreeRegressor(criterion='squared_error', max_depth=9)  # 旧版 sklearn 中该参数写作 'mse'
    dt = reg.fit(x, y)
    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)
    y_hat = dt.predict(x_test)
    plt.plot(x, y, 'r*', linewidth=2, label='Actual')
    plt.plot(x_test, y_hat, 'g-', linewidth=2, label='Predict')
    plt.legend(loc='upper left')
    plt.grid()
    plt.show()

    # 比较决策树的深度影响
    depth = [2, 4, 6, 8, 10]
    clr = 'rgbmy'
    reg = [DecisionTreeRegressor(criterion='squared_error', max_depth=d) for d in depth]

    plt.plot(x, y, 'k^', linewidth=2, label='Actual')
    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)
    for i, r in enumerate(reg):
        dt = r.fit(x, y)
        y_hat = dt.predict(x_test)
        plt.plot(x_test, y_hat, '-', color=clr[i], linewidth=2, label='Depth=%d' % depth[i])
    plt.legend(loc='upper left')
    plt.grid()
    plt.show()
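决策树回归的预测是分段常数函数:深度为 d 的树最多有 2^d 个叶结点,因此预测值最多只有 2^d 种,深度越大台阶越细、也越容易过拟合。下面的草图验证这一上界(示意,采样参数与正文不完全相同):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

np.random.seed(0)
x = np.sort(np.random.rand(100) * 6 - 3).reshape(-1, 1)
y = np.sin(x.ravel())

x_test = np.linspace(-3, 3, 200).reshape(-1, 1)
for d in (2, 4, 9):
    reg = DecisionTreeRegressor(max_depth=d).fit(x, y)
    y_hat = reg.predict(x_test)
    # 不同预测值的个数不超过叶结点数,即不超过 2^d
    print(d, len(np.unique(y_hat)))
```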

结果:

E:\pythonwork\venv\Scripts\python.exe E:/pythonspace/10.RandomForest/10.3.DecisionTreeRegressor.py
[-0.08762283 -0.16771537 -0.23860516 ...  0.36194898  0.40706424
  0.15147884]
[[-2.97990265]
 [-2.97337443]
 [-2.95321606]
 ...
 [ 2.77732631]
 [ 2.95712184]]
(y 与 x 各含 100 个随机样本值,中间部分从略)

Process finished with exit code 0


10.4.MultiOutput_DTR.py

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor

if __name__ == "__main__":
    N = 300
    x = np.random.rand(N) * 8 - 4     # [-4,4)
    x.sort()
    y1 = np.sin(x) + 3 + np.random.randn(N) * 0.1
    y2 = np.cos(0.3*x) + np.random.randn(N) * 0.01
    # y1 = np.sin(x) + np.random.randn(N) * 0.05
    # y2 = np.cos(x) + np.random.randn(N) * 0.1
    y = np.vstack((y1, y2)).T   # 形状 (N, 2):两个输出目标
    x = x.reshape(-1, 1)  # 变形为 N 行 1 列:N 个样本,每个样本 1 维

    deep = 3
    reg = DecisionTreeRegressor(criterion='squared_error', max_depth=deep)  # 旧版 sklearn 中该参数写作 'mse'
    dt = reg.fit(x, y)

    x_test = np.linspace(-4, 4, num=1000).reshape(-1, 1)
    print(x_test)
    y_hat = dt.predict(x_test)
    print(y_hat)
    plt.scatter(y[:, 0], y[:, 1], c='r', s=40, label='Actual')
    plt.scatter(y_hat[:, 0], y_hat[:, 1], c='g', marker='s', s=100, label='Depth=%d' % deep, alpha=1)
    plt.legend(loc='upper left')
    plt.xlabel('y1')
    plt.ylabel('y2')
    plt.grid()
    plt.show()
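DecisionTreeRegressor 原生支持多输出回归:y 的形状为 (N, 2) 时,每个叶结点同时保存两个输出的均值,predict 的结果同样是 (N, 2)。示意草图:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

np.random.seed(0)
x = np.sort(np.random.rand(300) * 8 - 4).reshape(-1, 1)
# 两个输出目标按列拼成 (N, 2)
y = np.column_stack((np.sin(x.ravel()) + 3, np.cos(0.3 * x.ravel())))

reg = DecisionTreeRegressor(max_depth=3).fit(x, y)
y_hat = reg.predict(np.linspace(-4, 4, 10).reshape(-1, 1))
print(y_hat.shape)   # (10, 2):每个测试点同时预测 y1 和 y2
```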

结果:

E:\pythonwork\venv\Scripts\python.exe E:/pythonspace/10.RandomForest/10.4.MultiOutput_DTR.py
[[-4.        ]
 [-3.99199199]
 [-3.98398398]
 ...
(x_test 为 np.linspace(-4, 4, num=1000) 的列向量,共 1000 行,其余输出从略)
 [ 3.88788789]
 [ 3.8958959 ]
 [ 3.9039039 ]
 [ 3.91191191]
 [ 3.91991992]
 [ 3.92792793]
 [ 3.93593594]
 [ 3.94394394]
 [ 3.95195195]
 [ 3.95995996]
 [ 3.96796797]
 [ 3.97597598]
 [ 3.98398398]
 [ 3.99199199]
 [ 4.        ]]
[[3.60980674 0.41652999]
 [3.60980674 0.41652999]
 [3.60980674 0.41652999]
 ...
 [2.4296138  0.43140993]
 [2.4296138  0.43140993]
 [2.4296138  0.43140993]]

Process finished with exit code 0
10.5.Iris_RandomForest_Enum.py

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.ensemble import RandomForestClassifier


def iris_type(s):
    it = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2}
    return it[s]

# feature names: 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'

if __name__ == "__main__":
    mpl.rcParams['font.sans-serif'] = [u'SimHei']  # use a Chinese font (SimHei) so axis labels render
    mpl.rcParams['axes.unicode_minus'] = False

    path = '..\\8.Regression\\8.iris.data'  # path to the data file
    data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
    x_prime, y = np.split(data, (4,), axis=1)

    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
    for i, pair in enumerate(feature_pairs):
        # select the two features of this pair
        x = x_prime[:, pair]

        # random forest
        clf = RandomForestClassifier(n_estimators=200, criterion='entropy', max_depth=4)
        rf_clf = clf.fit(x, y.ravel())

        # plotting grid
        N, M = 500, 500  # number of sample points along each axis
        x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of column 0
        x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of column 1
        t1 = np.linspace(x1_min, x1_max, N)
        t2 = np.linspace(x2_min, x2_max, M)
        x1, x2 = np.meshgrid(t1, t2)  # generate the grid of sample points
        x_test = np.stack((x1.flat, x2.flat), axis=1)  # test points

        # predictions on the training set
        y_hat = rf_clf.predict(x)
        y = y.reshape(-1)
        c = np.count_nonzero(y_hat == y)    # count the correct predictions
        print('特征:  ', iris_feature[pair[0]], ' + ', iris_feature[pair[1]])
        print('\t预测正确数目:', c)
        print('\t准确率: %.2f%%' % (100 * float(c) / float(len(y))))

        # display
        cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
        y_hat = rf_clf.predict(x_test)  # predictions on the grid
        y_hat = y_hat.reshape(x1.shape)  # reshape to match the grid
        plt.subplot(2, 3, i+1)
        plt.pcolormesh(x1, x2, y_hat, cmap=cm_light)  # predicted regions
        plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', cmap=cm_dark)  # samples
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.xlim(x1_min, x1_max)
        plt.ylim(x2_min, x2_max)
        plt.grid()
    plt.tight_layout(pad=2.5)  # pad must be passed by keyword in recent matplotlib
    plt.subplots_adjust(top=0.92)
    plt.suptitle(u'随机森林对鸢尾花数据的两特征组合的分类结果', fontsize=18)
    plt.show()

Output:

E:\pythonwork\venv\Scripts\python.exe E:/pythonspace/10.RandomForest/10.5.Iris_RandomForest_Enum.py
E:\pythonwork\venv\lib\site-packages\sklearn\ensemble\weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
  from numpy.core.umath_tests import inner1d
特征:   花萼长度  +  花萼宽度
	预测正确数目: 125
	准确率: 83.33%
特征:   花萼长度  +  花瓣长度
	预测正确数目: 145
	准确率: 96.67%
特征:   花萼长度  +  花瓣宽度
	预测正确数目: 146
	准确率: 97.33%
特征:   花萼宽度  +  花瓣长度
	预测正确数目: 144
	准确率: 96.00%
特征:   花萼宽度  +  花瓣宽度
	预测正确数目: 145
	准确率: 96.67%
特征:   花瓣长度  +  花瓣宽度
	预测正确数目: 146
	准确率: 97.33%

Process finished with exit code 0
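Incidentally, the hard-coded `feature_pairs` list is just every 2-combination of the four feature indices, so it can be generated rather than enumerated by hand. A small sketch with `itertools.combinations`:

```python
from itertools import combinations

# all unordered pairs of the 4 feature indices
feature_pairs = [list(p) for p in combinations(range(4), 2)]
print(feature_pairs)  # [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
```

This produces exactly the same six pairs, in the same order, and generalizes immediately if more features are added.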

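Note that the accuracies above are measured on the training data itself, so they are optimistic. Because each tree in a random forest is trained on a bootstrap sample, the forest can estimate generalization accuracy from the left-out (out-of-bag, OOB) samples without a separate test split. A minimal sketch, using sklearn's built-in `load_iris` instead of the local data file and the best-scoring pair above (petal length + petal width, columns 2 and 3):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
x, y = iris.data[:, 2:4], iris.target  # petal length + petal width

clf = RandomForestClassifier(n_estimators=200, criterion='entropy',
                             max_depth=4, oob_score=True, random_state=0)
clf.fit(x, y)
print('train accuracy: %.2f%%' % (100 * clf.score(x, y)))  # resubstitution accuracy
print('OOB accuracy:   %.2f%%' % (100 * clf.oob_score_))   # out-of-bag estimate
```

The OOB estimate is typically a little lower than the training accuracy and is a fairer summary of how the two-feature model would do on unseen irises.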