数据集中经常包含一些复杂的相互关系,使得输入数据和目标变量之间呈现非线性关系。对复杂的关系建模,一种方式用树对预测值分段,包括分段常数(回归树)和分段直线(模型树)。
CART算法可以用于构建二元树并处理离散型或连续型数据的切分,对于过拟合可以采取剪枝的办法:预剪枝(在树的构建过程中剪枝)和后剪枝(当树构建完毕进行剪枝)。
树回归的具体实现代码如下:
'''
Tree-Based Regression Methods
'''
from numpy import *
def loadDataSet(fileName): #general function to parse tab -delimited floats
dataMat = [] #assume last column is target value
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float,curLine) #map all elements to float()
dataMat.append(fltLine)
return dataMat
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
return mat0,mat1
def regLeaf(dataSet):#returns the value used for each leaf
return mean(dataSet[:,-1])
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
def linearSolve(dataSet): #helper function used in two places
m,n = shape(dataSet)
X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\n\
try increasing the second value of ops'