《机器学习实战笔记--第二部分 利用回归预测数值型数据:树回归2》

  模型树

    用树来对数据建模,除了把叶节点简单的设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数。这里所谓的分段线性是指模型由多个线性片段组成。

    

    上图就是两个线性模型,数据集中0.0-0.3以某个线性模型建模,另一部分以另外一个线性模型建模。

    很显然,两条直线比很多节点组成的一个大树更容易解释。模型树的可解释性是它由于回归树的特点之一。前面的代码稍加修改就可以在叶节点生成线性模型而不是常数值。

    为了找到最佳切分,计算误差应该先用线性模型来对他进行拟合,然后计算真实的目标值与模型预测值之间的差值。最后将这些差值的平方求和就得到了所需的误差。

# 模型树的叶节点生成函数
def linearSolve(dataset):
    m,n = shape(dataset)
    # 将数据集格式化成目标变量Y和自变量X。用于执行简单的线性回归。
    X = mat(ones((m,n)))
    Y = mat(ones((m,1)))
    X[:,1:n] = dataset[:,0:n-1]
    Y = dataset[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('this matrix is singular, cannot do inverse')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

# 当数据不在需要切分的时候就负责生成叶节点的模型。
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws

# 在给定的数据集上计算误差,会被chooseBestSplit()调用来找到最佳的切分。
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    # 平方差
    return sum(power(Y - yHat, 2))

    执行结果如下:

    

    可以看到该代码以0.285477为界创建了两个模型。而图9-4的数据实际在0.处分段。createTree()生成的这两个线性模型分别是0.0016386+11.965x 和 3.4688+1.1852x。实际上该图形是由模型y=3.5+1.0x和y=0+12x再加上高斯噪声生成的。下图是由数据生成的线性模型:

    

    模型树,回归树以及其他的模型,哪种更好呢?一个比较客观的比较方法是计算相关系数,也称为值。该相关系数可以通过Numpy库中的命令corrcoef(yHat, y, rowvar=0)来求解,其中yHat是预测值,y是真实值。

    

使用python的Tkinter库创建GUI

    同时支持数据呈现和用户交互的方式就是构建一个图形用户界面。

        

    首先介绍利用现有的tkinter来构建GUI,之后介绍如何在tkinter和绘图库之间交互。最后通过创建gui使人们能够探索模型树和回归树的奥秘。

   1. 用Tkinter创建GUI

     先从简单的hello world开始:

       

    出现小窗口及文字。

    tkinter的gui由一些小部件组成,指的是文本框,按钮,标签和复选按钮等对象。在上面的例子中,标签mylabel就是其中的唯一小部件。使用mylabel的grid()方法时,就等于把mylabel的位置告诉了布局管理器。tkinter中提供了集中不同的布局管理器,grid方法会把小部件安排在一个二维的表格中。下面将所需的小部件集成在一起构成树管理器。

from tkinter import *
from numpy import *
import regTrees

def reDraw(tolS, tolN):
    pass

def drawNewTree():
    pass

root = Tk()
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root) # 文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()# 为了读取复选按钮的状态
# 复选按钮
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0, 10)
               
root.mainloop()

    

    运行效果类似如此。


   2. 集成Matplotlib和Tkinter

     我们通过修改matplotlib的后端(仅在我们的gui上)达到在tkinter的gui上绘图的目的。

     matplotlib的构建程序包含一个前端,也就是面向用户的一些代码,例如plot()和scatter()方法。事实上,他同时也创建了一个后端,用于实现绘图和不同应用之间的接口。通过改变后端可以将图像绘制在PNG, PDF,SVG等格式的文件上。下面将设置后端为TkAgg(Agg是一个C++的库,可以从图像上创建光栅图)。TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。

from tkinter import *
from numpy import *
import regTrees

import matplotlib
matplotlib.use('TkAgg') # 设置后端
# 将TkAgg和matplotlib图连起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

# 把树绘制出来
def reDraw(tolS,tolN):
    reDraw.f.clf()        # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    # 检查复选框是否被选中
    if chkBtnVar.get():
        if tolN < 2: tolN = 2
        myTree=regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf,\
                                   regTrees.modelErr, (tolS,tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat, \
                                       regTrees.modelTreeEval)
    else:
        myTree=regTrees.createTree(reDraw.rawDat, ops=(tolS,tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1], s=5) #use scatter for data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) #use plot for yHat
    reDraw.canvas.show()


def getInputs():
    try: tolN = int(tolNentry.get()) #得到用户输入的文本
    except: 
        tolN = 10 
        print ("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0,'10')
    try: tolS = float(tolSentry.get())
    except: 
        tolS = 1.0 
        print ("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0,'1.0')
    return tolN,tolS

# 点击ReDraw按钮时就会调用该函数
def drawNewTree():
    # 拿到输入框的值
    tolN,tolS = getInputs()#get values from Entry boxes
    # 画图
    reDraw(tolS,tolN)

root = Tk()

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root) # 文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()# 为了读取复选按钮的状态
# 复选按钮
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0, 10)
               
root.mainloop()

    

    

    总结:

    数据集中一些复杂的关系,使得输入与目标变量之间呈现非线性关系。对于这些复杂的关系,一种可行的方式就是就是使用树来对预测值分段,包括分段常数或则分段直线。一般使用树结构建模。

    CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。该算法构建的树可能会倾向于过拟合。一颗过拟合的树常常十分复杂,我们就是用剪枝技术来处理解决这个问题。

    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值