模型树
用树来对数据建模,除了把叶节点简单的设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数。这里所谓的分段线性是指模型由多个线性片段组成。
上图就是两个线性模型,数据集中0.0-0.3以某个线性模型建模,另一部分以另外一个线性模型建模。
很显然,两条直线比很多节点组成的一个大树更容易解释。模型树的可解释性是它由于回归树的特点之一。前面的代码稍加修改就可以在叶节点生成线性模型而不是常数值。
为了找到最佳切分,计算误差应该先用线性模型来对他进行拟合,然后计算真实的目标值与模型预测值之间的差值。最后将这些差值的平方求和就得到了所需的误差。
# 模型树的叶节点生成函数
def linearSolve(dataset):
m,n = shape(dataset)
# 将数据集格式化成目标变量Y和自变量X。用于执行简单的线性回归。
X = mat(ones((m,n)))
Y = mat(ones((m,1)))
X[:,1:n] = dataset[:,0:n-1]
Y = dataset[:,-1]
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('this matrix is singular, cannot do inverse')
ws = xTx.I * (X.T * Y)
return ws, X, Y
# 当数据不在需要切分的时候就负责生成叶节点的模型。
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws
# 在给定的数据集上计算误差,会被chooseBestSplit()调用来找到最佳的切分。
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
# 平方差
return sum(power(Y - yHat, 2))
执行结果如下:
可以看到该代码以0.285477为界创建了两个模型。而图9-4的数据实际在0.处分段。createTree()生成的这两个线性模型分别是0.0016386+11.965x 和 3.4688+1.1852x。实际上该图形是由模型y=3.5+1.0x和y=0+12x再加上高斯噪声生成的。下图是由数据生成的线性模型:
模型树,回归树以及其他的模型,哪种更好呢?一个比较客观的比较方法是计算相关系数,也称为值。该相关系数可以通过Numpy库中的命令corrcoef(yHat, y, rowvar=0)来求解,其中yHat是预测值,y是真实值。
使用python的Tkinter库创建GUI
同时支持数据呈现和用户交互的方式就是构建一个图形用户界面。
首先介绍利用现有的tkinter来构建GUI,之后介绍如何在tkinter和绘图库之间交互。最后通过创建gui使人们能够探索模型树和回归树的奥秘。
1. 用Tkinter创建GUI
先从简单的hello world开始:
出现小窗口及文字。
tkinter的gui由一些小部件组成,指的是文本框,按钮,标签和复选按钮等对象。在上面的例子中,标签mylabel就是其中的唯一小部件。使用mylabel的grid()方法时,就等于把mylabel的位置告诉了布局管理器。tkinter中提供了集中不同的布局管理器,grid方法会把小部件安排在一个二维的表格中。下面将所需的小部件集成在一起构成树管理器。
from tkinter import *
from numpy import *
import regTrees
def reDraw(tolS, tolN):
pass
def drawNewTree():
pass
root = Tk()
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root) # 文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()# 为了读取复选按钮的状态
# 复选按钮
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0, 10)
root.mainloop()
运行效果类似如此。
2. 集成Matplotlib和Tkinter
我们通过修改matplotlib的后端(仅在我们的gui上)达到在tkinter的gui上绘图的目的。
matplotlib的构建程序包含一个前端,也就是面向用户的一些代码,例如plot()和scatter()方法。事实上,他同时也创建了一个后端,用于实现绘图和不同应用之间的接口。通过改变后端可以将图像绘制在PNG, PDF,SVG等格式的文件上。下面将设置后端为TkAgg(Agg是一个C++的库,可以从图像上创建光栅图)。TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。
from tkinter import *
from numpy import *
import regTrees
import matplotlib
matplotlib.use('TkAgg') # 设置后端
# 将TkAgg和matplotlib图连起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
# 把树绘制出来
def reDraw(tolS,tolN):
reDraw.f.clf() # clear the figure
reDraw.a = reDraw.f.add_subplot(111)
# 检查复选框是否被选中
if chkBtnVar.get():
if tolN < 2: tolN = 2
myTree=regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf,\
regTrees.modelErr, (tolS,tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat, \
regTrees.modelTreeEval)
else:
myTree=regTrees.createTree(reDraw.rawDat, ops=(tolS,tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1], s=5) #use scatter for data set
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) #use plot for yHat
reDraw.canvas.show()
def getInputs():
try: tolN = int(tolNentry.get()) #得到用户输入的文本
except:
tolN = 10
print ("enter Integer for tolN")
tolNentry.delete(0, END)
tolNentry.insert(0,'10')
try: tolS = float(tolSentry.get())
except:
tolS = 1.0
print ("enter Float for tolS")
tolSentry.delete(0, END)
tolSentry.insert(0,'1.0')
return tolN,tolS
# 点击ReDraw按钮时就会调用该函数
def drawNewTree():
# 拿到输入框的值
tolN,tolS = getInputs()#get values from Entry boxes
# 画图
reDraw(tolS,tolN)
root = Tk()
reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root) # 文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()# 为了读取复选按钮的状态
# 复选按钮
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0, 10)
root.mainloop()
总结:
数据集中一些复杂的关系,使得输入与目标变量之间呈现非线性关系。对于这些复杂的关系,一种可行的方式就是就是使用树来对预测值分段,包括分段常数或则分段直线。一般使用树结构建模。
CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。该算法构建的树可能会倾向于过拟合。一颗过拟合的树常常十分复杂,我们就是用剪枝技术来处理解决这个问题。