Machine Learning - Tree Regression

In the earlier article Machine Learning - Regression, we discussed ordinary linear regression. Those methods are powerful and very practical, but they have some drawbacks:

  1. They fit all of the sample points at once (except for locally weighted linear regression), which is computationally expensive
  2. Many real-world problems are nonlinear and cannot be modeled well with a linear model

This article introduces a nonlinear regression model, tree regression, built with the CART (Classification And Regression Tree) algorithm. The algorithm can be used both for classification (finding the category) and for regression (forecasting). CART is prone to overfitting, which can be addressed with tree pruning.

This article covers:

  • Regression trees
  • Tree pruning
  • Model trees
  • Classification and forecasting
  • Summary

Parts of this article are adapted from "Machine Learning in Action".


Regression Trees

A regression tree is generally used to partition the data into groups. The basic idea is to binary-split the original data set on certain feature values, much like building a binary search tree. A leaf node may hold a single data point's value or the mean of several data points, which can be controlled through parameters.

In Machine Learning - Decision Trees we used information gain to find the best split. Here, we instead look for the split that minimizes the total squared error. For example,

Suppose the original data set is S, and splitting on some feature value a divides S into S1 and S2. We then compute the sample variance of S, S1, and S2. If

var(S1) * N(S1) + var(S2) * N(S2) < var(S) * N(S)

where var(S1) is the sample variance of S1, N(S1) is the number of sample points in S1, and similarly for S2 and S, then splitting on a reduces the dispersion of the original data set S. By evaluating this for every feature value in S, we can find the best a; that value is the one we split on.

Note: the variance multiplied by the number of sample points is exactly the sum of squared differences between the sample values and their mean (the total squared error).
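The note is easy to verify numerically. Below is a minimal sketch (the toy values are made up purely for illustration) showing that var(S) * N(S) equals the total squared error, and how a candidate split is scored:

import numpy as np

s = np.array([1.0, 1.1, 0.9, 5.0, 5.2, 4.8])  # toy sample with two obvious groups

# var(S) * N(S) is exactly the sum of squared deviations from the mean
total_err = np.var(s) * len(s)
print(np.isclose(total_err, np.sum((s - s.mean()) ** 2)))  # True

# score the candidate split S1 = {x <= 3}, S2 = {x > 3}
s1, s2 = s[s <= 3], s[s > 3]
split_err = np.var(s1) * len(s1) + np.var(s2) * len(s2)
print(split_err < total_err)  # True: this split reduces the total error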

Let's implement the regression tree in code. Create a module reg_tree.py and enter the following code:

import numpy as np


def load_data_set(file_name):
    """Read a tab-delimited text file into a list of rows of floats."""
    data_mat = []
    with open(file_name) as f:
        for line in f:
            current_line = line.strip().split('\t')
            float_line = list(map(float, current_line))  # convert every field to float
            data_mat.append(float_line)
    return data_mat


def bin_split_data_set(data_set, feature, value):
    """Binary-split data_set on the given feature: rows above value go left."""
    mat0 = data_set[np.nonzero(data_set[:, feature] > value)[0], :]
    mat1 = data_set[np.nonzero(data_set[:, feature] <= value)[0], :]
    return mat0, mat1


def reg_leaf(data_set):
    """Leaf model of a regression tree: the mean of the target values."""
    return np.mean(data_set[:, -1])


def reg_err(data_set):
    """Total squared error: variance of the target times the sample count."""
    return np.var(data_set[:, -1]) * np.shape(data_set)[0]


def choose_best_split(data_set, leaf_type=reg_leaf, err_type=reg_err, ops=(1, 4)):
    tol_S = ops[0]  # minimum error reduction required for a split
    tol_N = ops[1]  # minimum number of samples allowed on either side of a split
    if len(set(data_set[:, -1].T.tolist()[0])) == 1:  # exit cond 1: all targets equal
        return None, leaf_type(data_set)
    m, n = np.shape(data_set)
    S = err_type(data_set)
    best_S = np.inf
    best_index = 0
    best_value = 0
    for feat_index in range(n - 1):
        for split_val in set(data_set[:, feat_index].T.tolist()[0]):
            mat0, mat1 = bin_split_data_set(data_set, feat_index, split_val)
            if (np.shape(mat0)[0] < tol_N) or (np.shape(mat1)[0] < tol_N):
                continue
            new_S = err_type(mat0) + err_type(mat1)
            if new_S < best_S:
                best_index = feat_index
                best_value = split_val
                best_S = new_S
    if (S - best_S) < tol_S:  # exit cond 2: the best split barely reduces the error
        return None, leaf_type(data_set)
    mat0, mat1 = bin_split_data_set(data_set, best_index, best_value)
    if (np.shape(mat0)[0] < tol_N) or (np.shape(mat1)[0] < tol_N):  # exit cond 3
        return None, leaf_type(data_set)
    return best_index, best_value


def create_tree(data_set, leaf_type=reg_leaf, err_type=reg_err,
                ops=(1, 4)):
    feat, val = choose_best_split(data_set, leaf_type, err_type, ops)
    if feat is None:  # no worthwhile split: val is a leaf value
        return val
    ret_tree = {}
    ret_tree['spInd'] = feat
    ret_tree['spVal'] = val
    lSet, rSet = bin_split_data_set(data_set, feat, val)
    ret_tree['left'] = create_tree(lSet, leaf_type, err_type, ops)
    ret_tree['right'] = create_tree(rSet, leaf_type, err_type, ops)
    return ret_tree


def get_reg_tree_values(tree):
    """Collect every split value in the tree (used for drawing split lines)."""
    result = []
    sp_val = tree['spVal']
    left = tree['left']
    right = tree['right']
    result.append(sp_val)
    if type(left) is dict:
        result.extend(get_reg_tree_values(left))
    if type(right) is dict:
        result.extend(get_reg_tree_values(right))
    return result


def get_reg_tree_leaf_values(tree):
    """Collect every leaf value (the mean of each group) in a regression tree."""
    result = []
    left = tree['left']
    right = tree['right']
    if type(left) is dict:
        result.extend(get_reg_tree_leaf_values(left))
    else:
        result.append(left)
    if type(right) is dict:
        result.extend(get_reg_tree_leaf_values(right))
    else:
        result.append(right)
    return result


def get_model_tree_values(tree):
    """Collect each leaf's regression coefficients as [intercept, slope] pairs."""
    result = []
    left = tree['left']
    right = tree['right']
    if type(left) is dict:
        result.extend(get_model_tree_values(left))
    else:
        left_data = np.array(left)
        result.append([left_data[0][0], left_data[1][0]])
    if type(right) is dict:
        result.extend(get_model_tree_values(right))
    else:
        right_data = np.array(right)
        result.append([right_data[0][0], right_data[1][0]])
    return result


if __name__ == '__main__':
    data = load_data_set('ex00.txt')
    mat = np.mat(data)
    tree = create_tree(mat)
    print(tree)

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree.py
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}

Process finished with exit code 0

As the output shows, the original data set was split in two on the feature value (feature index: 0, value: 0.48813); the mean of the left subtree is 1.0180967672413792 and the mean of the right subtree is -0.04465028571428572.
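We can sanity-check the two leaf values directly. The following small sketch (assuming ex00.txt sits in the working directory, as above) recomputes the group means by hand:

import numpy as np
import tree_regression.reg_tree as reg_tree

data = np.array(reg_tree.load_data_set('ex00.txt'))
left = data[data[:, 0] > 0.48813, 1]    # samples routed to the left subtree
right = data[data[:, 0] <= 0.48813, 1]  # samples routed to the right subtree
print(np.mean(left), np.mean(right))    # should reproduce the two leaf means above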

Let's visualize these relationships with a plot. Create a module reg_tree_plot.py and enter the following code:

import numpy as np
import matplotlib.pyplot as plt

import tree_regression.reg_tree as reg_tree


def test_dataset1():
    data = reg_tree.load_data_set('ex00.txt')
    mat = np.mat(data)
    tree = reg_tree.create_tree(mat)
    print(tree)

    x = np.array(data)[:, 0]
    y = np.array(data)[:, 1]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()


def test_dataset2():
    data = reg_tree.load_data_set('ex0.txt')
    mat = np.mat(data)
    tree = reg_tree.create_tree(mat)
    # tree = reg_tree.create_tree(mat, ops=(0.1, 10))
    print(tree)

    x = np.array(data)[:, 1]
    y = np.array(data)[:, 2]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()


def plot_lines(x_arr, y_min, y_max):
    for x in x_arr:
        line_x = (x, x)
        line_y = (y_min, y_max)
        plt.plot(line_x, line_y, label='Split Line')


def plot_mean_points(tree, min_x, max_x):
    line_x_arr = reg_tree.get_reg_tree_values(tree)
    mean_y_values = reg_tree.get_reg_tree_leaf_values(tree)
    mean_y_values.sort()
    print(mean_y_values)
    mean_x_values = []
    tmp_x = [min_x]
    tmp_x.extend(line_x_arr)
    tmp_x.append(max_x)
    tmp_x.sort()
    print(tmp_x)
    index = 0
    while index < len(tmp_x) - 1:
        mean_x_values.append((tmp_x[index] + tmp_x[index + 1]) / 2)
        index += 1
    plt.plot(mean_x_values, mean_y_values, 'or', label='Mean Values')


if __name__ == '__main__':
    test_dataset1()
    #test_dataset2()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
[-0.04465028571428572, 1.0180967672413792]
[0.000234, 0.48813, 0.996757]

Figure: (scatter plot of the data with the split line and the two group means)

Note the two red points in the figure: they are the mean values of the two groups, and the line between them is the split line.

Modify the code above to test dataset2 and look at a more complex case:

if __name__ == '__main__':
    #test_dataset1()
    test_dataset2()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
[-0.023838155555555553, 1.0289583666666666, 1.980035071428571, 2.9836209534883724, 3.9871632]
[0.004327, 0.197834, 0.39435, 0.582002, 0.797583, 0.998709]

Figure: (scatter plot of the data with the split lines and the group mean points)

The data set was split into five subsets, giving five leaf nodes.

As mentioned above, the parameters control the number of leaf nodes (a leaf node is either a single data point's value or the mean of several data points). Modify the function test_dataset2() to pass the argument ops=(0.1, 10):

def test_dataset2():
    data = reg_tree.load_data_set('ex0.txt')
    mat = np.mat(data)
    # tree = reg_tree.create_tree(mat)
    tree = reg_tree.create_tree(mat, ops=(0.1, 10))
    print(tree)

    x = np.array(data)[:, 1]
    y = np.array(data)[:, 2]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': {'spInd': 1, 'spVal': 0.486698, 'left': 2.0409245, 'right': 1.8810897500000001}}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': {'spInd': 1, 'spVal': 0.316465, 'left': 0.9437193846153846, 'right': 1.094141117647059}, 'right': {'spInd': 1, 'spVal': 0.148654, 'left': 0.07189454545454545, 'right': -0.054810500000000005}}}
[-0.054810500000000005, 0.07189454545454545, 0.9437193846153846, 1.094141117647059, 1.8810897500000001, 2.0409245, 2.9836209534883724, 3.9871632]
[0.004327, 0.148654, 0.197834, 0.316465, 0.39435, 0.486698, 0.582002, 0.797583, 0.998709]

Figure: (scatter plot of the data with the split lines and the group mean points)

We can see that the data is now divided into more subsets and the resulting tree is more complex, which allows a finer-grained partition.

The tree regression model is very sensitive to these two parameters:

  1. The first parameter controls the error tolerance. For a given data set, if the error of the original set (variance times sample count) minus the combined error of the two subsets produced by the best split is smaller than this tolerance, then even the best split does not reduce the overall error enough, so the set is not split at all and becomes a leaf node as a whole.
  2. The second parameter controls the minimum number of sample points a leaf node must contain. If the best split would leave the left or right subtree with fewer sample points than this value, the resulting leaves would be too small, so the set is not split.

By introducing these two parameters we can limit the size of the regression tree while it is being built, a technique known as pre-pruning; the sketch below illustrates how sensitive the tree is to them.
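As a quick illustration of that sensitivity, the sketch below (assuming ex0.txt is present and the chosen settings do not collapse the tree into a single leaf) builds trees with a few different ops values and counts the leaves with get_reg_tree_leaf_values:

import numpy as np
import tree_regression.reg_tree as reg_tree

mat = np.mat(reg_tree.load_data_set('ex0.txt'))
for ops in [(1, 4), (0.1, 10), (10, 20)]:
    tree = reg_tree.create_tree(mat, ops=ops)
    print(ops, '->', len(reg_tree.get_reg_tree_leaf_values(tree)), 'leaves')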

Tree Pruning

The previous section introduced pre-pruning, which prunes the regression tree while it is being built. There is also post-pruning, which starts from an already-built regression tree and a test set. This section discusses post-pruning.

The basic idea of post-pruning is to walk down to every pair of sibling leaf nodes, then compare the test-set error before merging the two leaves with the error after merging them: if the merged error is smaller, the two leaves are merged; otherwise they are kept.

Let's demonstrate post-pruning in code. Create a module tree_pruning.py and enter the following code:

import numpy as np
import tree_regression.reg_tree as rt


def is_tree(obj):
    """A subtree is stored as a dict; a leaf is a plain value."""
    return type(obj) is dict


def get_mean(tree):
    """Recursively collapse a subtree into the mean of its two branches."""
    if is_tree(tree['right']):
        tree['right'] = get_mean(tree['right'])
    if is_tree(tree['left']):
        tree['left'] = get_mean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, test_data):
    if np.shape(test_data)[0] == 0:
        return get_mean(tree)  # no test data reaches this subtree: collapse it
    if is_tree(tree['right']) or is_tree(tree['left']):  # split the test data so we can recurse
        l_set, r_set = rt.bin_split_data_set(test_data, tree['spInd'], tree['spVal'])
    if is_tree(tree['left']):
        tree['left'] = prune(tree['left'], l_set)
    if is_tree(tree['right']):
        tree['right'] = prune(tree['right'], r_set)
    # if both branches are now leaves, see if merging them lowers the test error
    if not is_tree(tree['left']) and not is_tree(tree['right']):
        l_set, r_set = rt.bin_split_data_set(test_data, tree['spInd'], tree['spVal'])
        error_no_merge = np.sum(np.power(l_set[:, -1] - tree['left'], 2)) + np.sum(
            np.power(r_set[:, -1] - tree['right'], 2))
        tree_mean = (tree['left'] + tree['right']) / 2.0
        error_merge = np.sum(np.power(test_data[:, -1] - tree_mean, 2))
        if error_merge < error_no_merge:
            print("merging...")
            return tree_mean
        else:
            return tree
    else:
        return tree


def test_pruning():
    print("Before pruning:")
    data = rt.load_data_set('ex2.txt')
    mat = np.mat(data)
    tree = rt.create_tree(mat)
    print(tree)
    print("After pruning:")
    test_mat = np.mat(rt.load_data_set('ex2test.txt'))
    prune_tree = prune(tree, test_mat)
    print(prune_tree)


if __name__ == '__main__':
    test_pruning()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/tree_pruning.py
Before pruning:
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': 104.82495374999999, 'right': 108.92921799999999}}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': {'spInd': 0, 'spVal': 0.553797, 'left': 97.20018024999999, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 109.38961049999999, 'right': 110.979946}, 'right': 101.73699325000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': 12.50675925, 'right': 3.4331330000000007}, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': -0.8923554999999995, 'right': 3.6584772500000016}}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.217214, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': 0.40377471428571476, 'right': -13.070501}, 'right': 6.770429}, 'right': -11.822278500000001}, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}
After pruning:
merging...
merging...
merging...
merging...
merging...
merging...
merging...
merging...
merging...
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': 106.87708587499999}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': 101.580533}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': 7.969946125, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': 1.383060875000001}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': -5.801872785714286, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}

Process finished with exit code 0

As shown above, nine pairs of leaf nodes were merged, which reduces the size of the regression tree to some extent.
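To quantify the effect, a small helper can compare the tree size before and after pruning; count_leaves below is a hypothetical function, not part of the modules above. Each merge replaces two leaves with one, so we expect nine fewer leaves:

import numpy as np
import tree_regression.reg_tree as rt
import tree_regression.tree_pruning as tp


def count_leaves(tree):
    # a plain value is a leaf; a dict is a subtree with two branches
    if not tp.is_tree(tree):
        return 1
    return count_leaves(tree['left']) + count_leaves(tree['right'])


tree = rt.create_tree(np.mat(rt.load_data_set('ex2.txt')))
n_before = count_leaves(tree)
pruned = tp.prune(tree, np.mat(rt.load_data_set('ex2test.txt')))
print(n_before, '->', count_leaves(pruned))  # expect 9 fewer leaves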

Model Trees

Model trees are very similar to regression trees, with two differences:

  1. A regression tree's leaf node is a single data point's value or the mean of several data points, while a model tree's leaf node is the set of linear regression coefficients fitted to a group of data points
  2. When computing the error, a regression tree measures against the mean, while a model tree measures against the linear regression's predicted values

Put another way, once the data has been partitioned, a regression tree approximates all sample points in a subset by the subset's mean, while a model tree describes the sample points in a subset with a regression equation. This is why regression trees are better suited to classification and model trees to forecasting.

Let's demonstrate model trees in code. Create a module model_tree.py and enter the following code:

import numpy as np
import matplotlib.pyplot as plt
import tree_regression.reg_tree as rt


def linear_solve(data_set):  # helper used by both model_leaf and model_err
    m, n = np.shape(data_set)
    X = np.mat(np.ones((m, n)))  # copy of the features with a 1 in the 0th position (intercept)
    Y = np.mat(np.ones((m, 1)))
    X[:, 1:n] = data_set[:, 0:n - 1]
    Y = data_set[:, -1]  # strip out the target column
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)  # ordinary least squares via the normal equations
    return ws, X, Y


def model_leaf(data_set):
    """Leaf model of a model tree: the fitted regression coefficients."""
    ws, X, Y = linear_solve(data_set)
    return ws


def model_err(data_set):
    """Squared error of the linear model's predictions on data_set."""
    ws, X, Y = linear_solve(data_set)
    y_hat = X * ws
    return sum(np.power(Y - y_hat, 2))


def cal_line_y(line_value, x):
    """Evaluate a leaf's line f(x) = w0 + w1 * x."""
    return line_value[0] + line_value[1] * x


if __name__ == "__main__":
    my_mat = np.mat(rt.load_data_set('exp2.txt'))
    tree = rt.create_tree(my_mat, model_leaf, model_err)
    print(tree)

    line_values = rt.get_model_tree_values(tree)
    print(line_values)

    for line_value in line_values:
        x = np.linspace(-0.1, 1.2)
        y = cal_line_y(line_value, x)
        plt.plot(x, y, label="Regression Line: f(x)=%f * x + %f" % (line_value[1], line_value[0]))

    x = np.array(my_mat)[:, 0]
    y = np.array(my_mat)[:, 1]
    plt.plot(x, y, 'o', label='Original Values')

    plt.title('Model Tree')
    plt.legend()
    plt.show()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/model_tree.py
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]]), 'right': matrix([[3.46877936],
        [1.18521743]])}
[[0.0016985569360628006, 11.964773944277027], [3.4687793552577872, 1.1852174309188115]]

Figure: (scatter plot of exp2.txt with the two fitted regression lines)

We can see that this model tree has two leaf nodes representing two different regression equations, so samples in different segments are predicted piecewise.
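For example, to forecast by hand at x = 0.6, we pick the left leaf (since 0.6 > 0.285477) and evaluate its line with cal_line_y; the coefficients below are copied from the output above:

import tree_regression.model_tree as mt

# left leaf: f(x) = 0.001699 + 11.964774 * x (coefficients from the run above)
print(mt.cal_line_y([0.0016985569360628006, 11.964773944277027], 0.6))  # about 7.18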

Classification and Forecasting

By traversing a regression tree or a model tree we can determine which group a sample point belongs to, or compute its forecast value. Create a module forecast.py and enter the following code:

import numpy as np
import tree_regression.tree_pruning as tp
import tree_regression.reg_tree as rt
import tree_regression.model_tree as mt


def reg_tree_eval(model, in_dat):
    """Evaluate a regression-tree leaf: the leaf value itself is the forecast."""
    return float(model)


def model_tree_eval(model, in_dat):
    """Evaluate a model-tree leaf: prepend the intercept term, then apply the coefficients."""
    n = np.shape(in_dat)[1]
    X = np.mat(np.ones((1, n + 1)))
    X[:, 1:n + 1] = in_dat
    return float(X * model)


def tree_forecast(tree, in_data, model_eval=reg_tree_eval):
    if not tp.is_tree(tree):
        return model_eval(tree, in_data)
    if in_data[tree['spInd']] > tree['spVal']:
        if tp.is_tree(tree['left']):
            return tree_forecast(tree['left'], in_data, model_eval)
        else:
            return model_eval(tree['left'], in_data)
    else:
        if tp.is_tree(tree['right']):
            return tree_forecast(tree['right'], in_data, model_eval)
        else:
            return model_eval(tree['right'], in_data)


def create_forecast(tree, test_data, model_eval=reg_tree_eval):
    m = len(test_data)
    y_hat = np.mat(np.zeros((m, 1)))
    for i in range(m):
        y_hat[i, 0] = tree_forecast(tree, np.mat(test_data[i]), model_eval)
    return y_hat


def test_reg_tree():
    print("Test regression tree:")
    data = rt.load_data_set('ex0.txt')
    tree = rt.create_tree(np.mat(data))
    print(tree)
    in_dat = (1.000000, 0.558918)
    y = 1.719148
    y_hat = tree_forecast(tree, in_dat)
    print("Real Y: %f" % y)
    print("Hat Y: %f" % y_hat)


def test_model_tree():
    print("Test model tree:")
    data = rt.load_data_set('exp2.txt')
    tree = rt.create_tree(np.mat(data), mt.model_leaf, mt.model_err)
    print(tree)
    in_dat = np.array([(0.010767,)])
    y = 3.565835
    y_hat = tree_forecast(tree, in_dat, model_eval=model_tree_eval)
    print("Real Y: %f" % y)
    print("Hat Y: %f" % y_hat)


if __name__ == '__main__':
    test_reg_tree()
    test_model_tree()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/forecast.py
Test regression tree:
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
Real Y: 1.719148
Hat Y: 1.980035
Test model tree:
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]]), 'right': matrix([[3.46877936],
        [1.18521743]])}
Real Y: 3.565835
Hat Y: 3.481541

Process finished with exit code 0
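A common way to judge forecast quality is the correlation coefficient between predicted and actual values. The sketch below reuses ex0.txt as both training and evaluation data purely for illustration (in practice you would evaluate on a held-out test set), calling tree_forecast row by row:

import numpy as np
import tree_regression.reg_tree as rt
import tree_regression.forecast as fc

data = rt.load_data_set('ex0.txt')
tree = rt.create_tree(np.mat(data))
y = [row[-1] for row in data]                                      # actual values
y_hat = [fc.tree_forecast(tree, tuple(row[:-1])) for row in data]  # forecasts
print(np.corrcoef(y_hat, y)[0, 1])  # closer to 1.0 means a better fit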

Summary

Data sets often contain complex interactions that make the relationship between the data and the target variable nonlinear. One effective way to model such data sets is to use a tree to make piecewise predictions, either piecewise constant or piecewise linear. A piecewise constant is the mean of a data subset; it is generally used for classification and corresponds to the regression tree. A piecewise line is the regression equation of a data subset; it is generally used for forecasting and corresponds to the model tree.

CART splits the data set with a binary tree. If overfitting occurs, tree pruning can remove redundant leaf nodes. Pruning comes in two forms: pre-pruning removes unnecessary leaves while the tree is being built and requires the user to supply two parameters, while post-pruning starts from an already-built regression or model tree together with a test set.
