这里我就不介绍局部加权线性回归这个算法的原理了,很多不错的博客介绍了这个很优秀的算法。它很适用于非线性的回归问题,吴恩答在斯坦福《机器学习》课程,也介绍了他使用了这个算法设计了一个能“倒着“飞的无人机。详细这个算法的说明可以上网易公开课观看,讲的很详细。
先贴代码:
import pymysql import matplotlib.pyplot as plt from numpy import * from scipy import * def lwlr(testPoint,xArr,yArr,k=1.0):#这个函数是为了估计一个点的值,xArr和yArr分别代表的是画图的x轴和y轴,也就是我们设的时间和车流量值 xMat=mat(xArr) yMat=mat(yArr).T m=shape(xMat)[0]#读取xMat里面的行数 weights=mat(eye((m))) for i in range(m): diffMat = testPoint - xMat[i,:] weights[i,i]=exp(diffMat*diffMat.T/(-2.0*k**2)) xTx=xMat.T*weights*xMat if linalg.det(xTx) ==0.0: print('奇异矩阵无法计算') return theta=xTx.I*(xMat.T*(weights*yMat.T)) return testPoint*theta def lwlrTest(testArr,xArr,yArr,k=1.0): m=shape(testArr)[0] yHat=zeros(m) for i in range(m): yHat[i]=lwlr(testArr[i],xArr,yArr,k) return yHat def datasetsX():#从数据库获取数据 conn=pymysql.connect(host='localhost',user='root',passwd='123456',db='cardb',port=3306) cur=conn.cursor()#获取游标 cur.execute("select minute from data") xData=cur.fetchall() cur.close() conn.close() return list(xData) def datasetsY1():#从一个提取车流量 conn = pymysql.connect(host='localhost', user='root', passwd='123456',db='cardb', port=3306) cur = conn.cursor() # 获取游标 cur.execute("select num1 from data") yData1=cur.fetchall() cur.close() conn.close() return list(yData1) def datasetsY2():#从反方向提取车流量 conn=pymysql.connect(host='localhost', user='root',passwd='123456', db='cardb', port=3306) cur=conn.cursor() cur.execute("select num2 from data") yData2=cur.fetchall() cur.close() conn.close() return list(yData2) xArr= datasetsX() yArr1=datasetsY1() yArr2=datasetsY2() yHat1=lwlrTest(xArr,xArr,yArr1,0.8) yHat2=lwlrTest(xArr,xArr,yArr2,0.8) xMat = mat(xArr) strInd = xMat[:, 0].argsort(0) xSort = xMat[strInd][:, 0, :] fig = plt.figure() ax = fig.add_subplot(2, 1, 1) ax.plot(xSort[:, 0], yHat1[strInd]) ax.scatter(xMat[:, 0].flatten().A[0], mat(yArr1).T.flatten().A[0], s = 2, c = 'red') bx = fig.add_subplot(2, 1, 2) bx.plot(xSort[:, 0], yHat2[strInd]) bx.scatter(xMat[:, 0].flatten().A[0], mat(yArr2).T.flatten().A[0], s = 2, c = 'red') plt.show()
因为我的程序是为了设计智能交通系统,所以从数据库里提取流量信息。大家可以选择提取数据的方法,但一定要注意矩阵运算有严格的格式要求,在运算的时候一定要先把数据统一好,这里我是用mat函数把数据转化为矩阵形式。这里的算法实现需要一定的概率论和矩阵论知识,比如我在矩阵相乘的这个过程中少加了一个转置符号所以调试了两天,最后复习矩阵相关的知识,一步步复查才检查出错误。相关知识很重要,不然连错误在哪都不知道。供大家相互交流。
!!!!!!!warning:这个算法能实现有一定规律的非线性回归的问题,但是对于那种完全随机的问题,这个算法没有一点用,因为预测只能找出规律,不能预测没有规律的事物。这个算法在数据量超级大的时候,它的计算代价比较大,具体优化可以看 kd-tree 相关知识。