在学习集成方法的过程中,顺着思路来到CART回归树,它作为GBDT的基学习器,是以均方误差作为损失函数,找到其取极小值时的点作为切分点,将数据集划分为左右子树,然后继续上面的步骤。
下面是代码部分,由于《机器学习实战》书中的代码存在部分错误,下面给予修正。
# _*_ coding: UTF-8 _*_ from numpy import * import numpy as np import pickle import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties # 设置字体属性 def loadDataSet(fileName): ''' 读取一个一tab键为分隔符的文件,然后将每行的内容保存成一组浮点数 ''' dataMat = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float,curLine)) dataMat.append(fltLine) return dataMat def binSplitDataSet(dataSet, feature, value): ''' 数据集切分函数 ''' mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] return mat0,mat1 #--------------回归树所需子函数---------------# def regLeaf(dataSet): '''负责生成叶节点''' #当chooseBestSplit()函数确定不再对数据进行切分时,将调用本函数来得到叶节点的模型。 #在回归树中,该模型其实就是目标变量的均值。 return np.mean(dataSet[:,-1]) def regErr(dataSet): ''' 误差估计函数,该函数在给定的数据上计算目标变量的平方误差,这里直接调用均方差函数 ''' return var(dataSet[:,-1]) * shape(dataSet)[0]#返回总方差 #--------------回归树子函数 END --------------# def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): ''' 用最佳方式切分数据集和生成相应的叶节点 ''' #ops为用户指定参数,用于控制函数的停止时机 tolS = ops[0]; tolN = ops[1] #如果所有值相等则退出 if len(set(dataSet[:,-1].T.tolist()[0])) == 1: return None, leafType(dataSet) m,n = shape(dataSet) S = errType(dataSet) bestS = inf; bestIndex = 0; bestValue = 0 #在所有可能的特征及其可能取值上遍历,找到最佳的切分方式 #最佳切分也就是使得切分后能达到最低误差的切分 for featIndex in range(n-1): for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]): mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue newS = errType(mat0) + errType(mat1) if newS < bestS: bestIndex = featIndex bestValue = splitVal bestS = newS #如果误差减小不大则退出 if (S - bestS) < tolS: return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) #如果切分出的数据集很小则退出 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): return None, leafType(dataSet) #提前终止条件都不满足,返回切分特征和特征值 return bestIndex,bestValue def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): ''' 树构建函数 leafType:建立叶节点的函数 errType:误差计算函数 ops:包含树构建所需其他参数的元组 ''' #选择最优的划分特征 #如果满足停止条件,将返回None和某类模型的值 #若构建的是回归树,该模型是一个常数;如果是模型树,其模型是一个线性方程 feat, val = chooseBestSplit(dataSet, leafType, errType, ops) if feat == None: return val # retTree = {} retTree['spInd'] = feat retTree['spVal'] = val #将数据集分为两份,之后递归调用继续划分 lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree def storeTree(inputTree, filename): with open(filename, 'wb') as fw: pickle.dump(inputTree, fw) if __name__ == '__main__': myDat = loadDataSet('ex00.txt') x=[x[0] for x in myDat] y=[y[1] for y in myDat] font=FontProperties(fname=r"c:\windows\fonts\simsun.ttc",size=14) plt.figure(figsize=(8,4)) plt.scatter(x,y) plt.xlabel("x") plt.ylabel("y") plt.title(u"基于CART算法构建回归树的简单数据集",fontproperties=font) # 在图上输出中文标题 plt.show() myMat = mat(myDat) retTree = createTree(myMat) print(retTree) storeTree(retTree, 'retTree.txt')
运行效果如下:
(1)输出数据集
(2)输出回归树
修改数据集为ex0.txt,输出结果如下:
(1)散点图
(2)回归树
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583
, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal'
: 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}