1、CART生成
回归树采用平方误差最小化准则。
分类树采用基尼指数最小化准则。
from numpy import *
from matplotlib import pyplot as plt
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine)) # 将数据映射成浮点型,返回的是map的地址 dataMat.append(fltLine)
dataMat.append(list(fltLine))
return dataMat
def plotData(dataMat):
x = dataMat[:, 0].tolist()
y = dataMat[:, 1].tolist()
plt.scatter(x, y)
plt.title('DataSet')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
def plotData2(dataMat):
x = dataMat[:, 1].tolist()
y = dataMat[:, 2].tolist()
plt.scatter(x, y)
plt.title('DataSet')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
def binSplitDataSet(dataSet, feature, value):
"""
根据某个特征及值对数据进行切分
:param dataSet: 数据集
:param feature: 特征
:param value: 特征值
:return: 切分后的数据
"""
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
def regLeaf(dataSet):
"""
生成叶节点
:param dataSet: 数据集
:return: 使用均值作为叶节点的值
"""
return mean(dataSet[:, -1])
# 误差估计函数
def regErr(dataSet):
"""
误差估计函数
:param dataSet: 数据集
:return: 总方差
"""
return var(dataSet[:, -1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
"""
找到数据的最佳二元切分方式函数
:param dataSet:数据组
:param leafType:生成叶节点
:param errType:误差估计函数
:param ops:用户自定义参数构成的元组
:return:
bestIndex:最佳切分特征的下标
bestValue:最优切分值
"""
"""
伪代码:
首先判断是否所有值相等:
否:计算误差值,初始化最佳误差,最优切分特征下标,最优切分值
遍历每个特征:
遍历特征的每个值:
将数据切分为两部分
判断是否小于最少切分样本值:
是:跳出循环
否:计算切分后误差
如果切分后误差小于当前最佳误差,则更新最优切分特征下标,最优切分值,最优误差
如果误差下降值小于最小误差值,则返回最佳切分特征值和最优切分值
"""
tolS = ops[0];
tolN = ops[1] # tolS是允许的最小误差下降值,tolN是切分的最少样本
if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 如果所有值相等则退出,根据set特性
return None, leafType(dataSet)
m, n = shape(dataSet)
S = errType(dataSet) #计算误差值
bestS = inf;
bestIndex = 0;
bestValue = 0
for featIndex in range(n - 1): #遍历每个特征值
for splitVal in set((dataSet[:, featIndex].T.tolist())[0]): # 遍历特征的每个值
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) #将数据集切分为两部分
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue #如果小于切分的最少样本数则直接跳到下一次循环
newS = errType(mat0) + errType(mat1) # 计算新的误差
if newS < bestS: # 如果小于最佳误差则更新
bestIndex = featIndex
bestValue = splitVal
bestS = newS
# 如果误差下降值小于允许的最小误差下降值,则不需要切分
if (S - bestS) < tolS:
return None, leafType(dataSet)
# 切分数据集
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
# 如果切分后的数据集个数小于tolN,则不进行切分
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue
def createTree(dataSet, leafType=regLeaf, errType=regErr,ops=(1, 4)):
"""
建树
:param dataSet:数据集
:param leafType:生成叶子节点函数
:param errType:误差函数
:param ops:用户自定义参数,对树进行预剪枝
:return:树
"""
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
# 测试
if __name__ == '__main__':
ex00Data = loadDataSet('ex00.txt')
ex00Mat = mat(ex00Data)
plotData(ex00Mat)
myTree1 = createTree(ex00Mat)
print(myTree1)
ex0Data = loadDataSet('ex0.txt')
ex0Mat = mat(ex0Data)
plotData2(ex0Mat)
myTree2 = createTree(ex0Mat)
print(myTree2)
print(createTree(ex0Mat, ops=(0,1)))
2、树剪枝
(1)预剪枝
if __name__ == '__main__':
ex2Data = loadDataSet('ex2.txt')
ex2Mat = mat(ex2Data)
plotData(ex2Mat)
ex2Tree = createTree(ex2Mat)
print(ex2Tree) #非常多的叶子节点
print(createTree(ex2Mat,ops=(10000,4))) #只有两个叶子节点,tolS对误差的数据集非常敏感
(2)后剪枝
def isTree(obj):
"""
函数说明:判断是否是一棵树
:param obj:
:return:
"""
return (type(obj).__name__=='dict')
def getMean(tree):
"""
函数说明:递归函数,对树进行塌陷处理(返回树平均值)
:param tree: 树
:return: 树的左右子树的均值
"""
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
return (tree['left'] + tree['right']) / 2.0
def prnue(tree, testData):
"""
函数说明:后剪枝函数,使用测试数据对树进行剪枝
:param tree: 待剪枝的树
:param testData: 测试数据
:return: 剪枝后的树
"""
"""
树的后剪枝:
伪代码:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算没有合并时的误差
如果合并可以降低误差则进行合并
"""
#如果测试数据集为空,进行塌陷处理
if(shape(testData[0]) == 0): return getMean(tree)
# if isTree(tree['left']) and isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prnue(tree['left'], lSet) #左子集不是单个节点,则对左子树进行剪枝
if isTree(tree['right']): tree['right'] = prnue(tree['right'], rSet) #右子集不是单个节点,则对右子树进行剪枝
if not isTree(tree['left']) and not isTree(tree['right']): #左右子集都是单个节点
errNoMerge = sum(power((lSet[:,-1] - tree['left']),2)) + sum(power((rSet[:,-1] - tree['right']),2)) #未合并的误差
treeMean = getMean(tree)
errMerge = sum(power(testData[:,-1] - treeMean, 2)) #合并后误差
print(errNoMerge, errMerge)
if errMerge < errNoMerge: #如果合并后误差小于合并前误差则合并
print("merging")
return treeMean
else: return tree
else: return tree
if __name__ == '__main__':
ex2Data = loadDataSet('ex2.txt')
ex2Mat = mat(ex2Data)
plotData(ex2Mat)
ex2Tree = createTree(ex2Mat)
print(ex2Tree)
ex2TestData = loadDataSet('ex2test.txt')
prnueTree = prnue(ex2Tree, mat(ex2TestData))
print(prnueTree)
3、进行预测
简单线性回归、回归树、模型树的对比
使用决定系数R^2值来判断预测效果,越接近于1越好。