机器学习理论与实战——回归

按照《机器学习实战》的主线,结束有监督学习中关于分类的机器学习方法,进入回归部分。所谓回归就是数据进行曲线拟合,回归一般用来做预测,涵盖线性回归(经典最小二乘法)、局部加权线性回归、岭回归和逐步线性回归。先来看下线性回归,即经典最小二乘法,说到最小二乘法就不得说下线性代数,因为一般说线性回归只通过计算一个公式就可以得到答案,如(公式一)所示:


(公式一)

其中X是表示样本特征组成的矩阵,Y表示对应的值,比如房价,股票走势等,(公式一)得到是直接通过对(公式二)求导得到的,因为(公式二)是凸函数,导数等于零的点就是最小点。

(公式二)

不过并不是所有的码农能从(公式二)求导得到(公式一)的解,因此这里给出另外一个直观的解,直观理解建立起来后,后续几个回归就简单类推咯。从初中的投影点说起,如(图一)所示:

(图一)

在(图一)中直线a上离点b最近的点是点b在其上的投影,即垂直于a的交点p。p是b在a上的投影点。试想一下,如果我们把WX看成多维的a,即空间中的一个超面来代替二维空间中的直线,而y看成b,那现在要使得(公式二)最小是不是就是寻找(图一)中的e,即垂直于WX的垂线,因为只有垂直时e才最小。下面来看看如何通过寻找垂线并最终得到W。要寻找垂线,先从(图二)中的夹角theta 说起吧,因为当cos(theta)=0时,他们也就垂直了。下面来分析下直线或者向量之间的夹角,如(图二)所示:


(图二)

在(图二)中, 表示三角形 的斜边,那么:

角beta也可以得到同样的计算公式,接着利用三角形和差公式得到(公式三):


(公式三)

(公式三)表示的是两直线或者两向量之间的夹角公式,很多同学都学过。再仔细看下,发现分子其实是向量a,b之间的内积(点积),因此公式三变为简洁的(公式四)的样子:

(公式四)

接下来继续分析(图一)中的投影,为了方便观看,增加了一些提示如(图三)所示:

(图三)

在(图三)中,假设向量b在向量a中的投影为p(注意,这里都上升为向量空间,不再使用直线,因为(公式四)是通用的)。投影p和a 在同一方向上(也可以反方向),因此我们可以用一个系数乘上a来表示p,比如(图三)中的 ,有了投影向量p,那么我们就可以表示向量e,因为根据向量法则, ,有因为a和e垂直,因此 ,展开求得系数x,如(公式五)所示:


(公式五)

(公式五)是不是很像(公式一)?只不过公式一的分母写成了另外的形式,不过别急,现在的系数只是一个数字,因为a,b都是一个向量,我们要扩展一下,把a从向量扩展到子空间,因为(公式一)中的X是样本矩阵,矩阵有列空间和行空间,如(图四)所示:

(图四)

(图四)中的A表示样本矩阵X,假设它有两个列a1和a2,我们要找一些线性组合系数来找一个和(图三)一样的接受b 投影的向量,而这个向量通过矩阵列和系数的线性组合表示。求解的这个系数的思路和上面完全一样,就是寻找投影所在的向量和垂线e的垂直关系,得到系数,如(公式六)所示:

(公式六)

这下(公式六)和(公式一)完全一样了,基于最小二乘法的线性回归也就推导完成了,而局部加权回归其实只是相当于对不同样本之间的关系给出了一个权重,所以叫局部加权,如(公式七)所示:

(公式七)

而权重的计算可通过高斯核(高斯公式)来完成,核的作用就是做权重衰减,很多地方都要用到,表示样本的重要程度,一般离目标进的重要程度大些,高斯核可以很好的描述这种关系。如(公式八)所示,其中K是个超参数,根据情况灵活设置:

(公式八)

(图五)是当K分别为1.0, 0.01,0.003时的局部加权线性回归的样子,可以看出当K=1.0时,和线性回归没区别:

(图五)

而岭回归的样子如(公式九)所示:

(公式九)

岭回归主要是解决的问题就是当XX’无法求逆时,比如当特征很多,样本很少,矩阵X不是满秩矩阵,此时求逆会出错,但是通过加上一个对角为常量lambda的矩阵,就可以很巧妙的避免这个计算问题,因此会多一个参数lambda,lambda的最优选择由交叉验证(cross-validation)来决定,加上一个对角不为0的矩阵很形象的在对角上抬高了,因此称为岭。不同的lambda会使得系数缩减,如(图六)所示:

(图六)

说到系数缩减大家可能会觉得有奇怪,感觉有点类似于正则,但是这里只是相当于在(公式六)中增大分母,进而缩小系数,另外还有一些系数缩减的方法,比如直接增加一些约束,如(公式十)和(公式十一)所示:

(公式十)

(公式十一)

当线性回归增加了(公式十)的约束变得和桥回归差不多,系数缩减了,而如果增加了(公式十一)的约束时就是稀疏回归咯,(我自己造的名词,sorry),系数有一些0。

有了约束后,求解起来就不像上面那样直接计算个矩阵运算就行了,回顾第五节说中支持向量机原理,需要使用二次规划求解,不过仍然有一些像SMO算法一样的简化求解算法,比如前向逐步回归方法:

前向逐步回归的伪代码如(图七)所示,也不难,仔细阅读代码就可以理解:

(图七)

下面直接给出上面四种回归的代码:

[python] view plain copy
  1. from numpy import *

  2. def loadDataSet(fileName):#general function to parse tab -delimited floats
  3. numFeat = len(open(fileName).readline().split('\t')) -1 #get number of fields
  4. dataMat = []; labelMat = []
  5. fr = open(fileName)
  6. for line in fr.readlines():
  7. lineArr =[]
  8. curLine = line.strip().split('\t')
  9. for i in range(numFeat):
  10. lineArr.append(float(curLine[i]))
  11. dataMat.append(lineArr)
  12. labelMat.append(float(curLine[-1]))
  13. return dataMat,labelMat

  14. def standRegres(xArr,yArr):
  15. xMat = mat(xArr); yMat = mat(yArr).T
  16. xTx = xMat.T*xMat
  17. if linalg.det(xTx) ==0.0:
  18. print "This matrix is singular, cannot do inverse"
  19. return
  20. ws = xTx.I * (xMat.T*yMat)
  21. return ws

  22. def lwlr(testPoint,xArr,yArr,k=1.0):
  23. xMat = mat(xArr); yMat = mat(yArr).T
  24. m = shape(xMat)[0]
  25. weights = mat(eye((m)))
  26. for j in range(m):#next 2 lines create weights matrix
  27. diffMat = testPoint - xMat[j,:] #
  28. weights[j,j] = exp(diffMat*diffMat.T/(-2.0*k**2))
  29. xTx = xMat.T * (weights * xMat)
  30. if linalg.det(xTx) ==0.0:
  31. print "This matrix is singular, cannot do inverse"
  32. return
  33. ws = xTx.I * (xMat.T * (weights * yMat))
  34. return testPoint * ws

  35. def lwlrTest(testArr,xArr,yArr,k=1.0):#loops over all the data points and applies lwlr to each one
  36. m = shape(testArr)[0]
  37. yHat = zeros(m)
  38. for i in range(m):
  39. yHat[i] = lwlr(testArr[i],xArr,yArr,k)
  40. return yHat

  41. def lwlrTestPlot(xArr,yArr,k=1.0):#same thing as lwlrTest except it sorts X first
  42. yHat = zeros(shape(yArr)) #easier for plotting
  43. xCopy = mat(xArr)
  44. xCopy.sort(0)
  45. for i in range(shape(xArr)[0]):
  46. yHat[i] = lwlr(xCopy[i],xArr,yArr,k)
  47. return yHat,xCopy

  48. def rssError(yArr,yHatArr): #yArr and yHatArr both need to be arrays
  49. return ((yArr-yHatArr)**2).sum()

  50. def ridgeRegres(xMat,yMat,lam=0.2):
  51. xTx = xMat.T*xMat
  52. denom = xTx + eye(shape(xMat)[1])*lam
  53. if linalg.det(denom) == 0.0:
  54. print "This matrix is singular, cannot do inverse"
  55. return
  56. ws = denom.I * (xMat.T*yMat)
  57. return ws

  58. def ridgeTest(xArr,yArr):
  59. xMat = mat(xArr); yMat=mat(yArr).T
  60. yMean = mean(yMat,0)
  61. yMat = yMat - yMean #to eliminate X0 take mean off of Y
  62. #regularize X's
  63. xMeans = mean(xMat,0)#calc mean then subtract it off
  64. xVar = var(xMat,0) #calc variance of Xi then divide by it
  65. xMat = (xMat - xMeans)/xVar
  66. numTestPts = 30
  67. wMat = zeros((numTestPts,shape(xMat)[1]))
  68. for i in range(numTestPts):
  69. ws = ridgeRegres(xMat,yMat,exp(i-10))
  70. wMat[i,:]=ws.T
  71. return wMat

  72. def regularize(xMat):#regularize by columns
  73. inMat = xMat.copy()
  74. inMeans = mean(inMat,0)#calc mean then subtract it off
  75. inVar = var(inMat,0) #calc variance of Xi then divide by it
  76. inMat = (inMat - inMeans)/inVar
  77. return inMat

  78. def stageWise(xArr,yArr,eps=0.01,numIt=100):
  79. xMat = mat(xArr); yMat=mat(yArr).T
  80. yMean = mean(yMat,0)
  81. yMat = yMat - yMean #can also regularize ys but will get smaller coef
  82. xMat = regularize(xMat)
  83. m,n=shape(xMat)
  84. #returnMat = zeros((numIt,n)) #testing code remove
  85. ws = zeros((n,1)); wsTest = ws.copy(); wsMax = ws.copy()
  86. for i in range(numIt):
  87. print ws.T
  88. lowestError = inf;
  89. for j in range(n):
  90. for sign in [-1,1]:
  91. wsTest = ws.copy()
  92. wsTest[j] += eps*sign
  93. yTest = xMat*wsTest
  94. rssE = rssError(yMat.A,yTest.A)
  95. if rssE < lowestError:
  96. lowestError = rssE
  97. wsMax = wsTest
  98. ws = wsMax.copy()
  99. #returnMat[i,:]=ws.T
  100. #return returnMat

  101. #def scrapePage(inFile,outFile,yr,numPce,origPrc):
  102. # from BeautifulSoup import BeautifulSoup
  103. # fr = open(inFile); fw=open(outFile,'a') #a is append mode writing
  104. # soup = BeautifulSoup(fr.read())
  105. # i=1
  106. # currentRow = soup.findAll('table', r="%d" % i)
  107. # while(len(currentRow)!=0):
  108. # title = currentRow[0].findAll('a')[1].text
  109. # lwrTitle = title.lower()
  110. # if (lwrTitle.find('new') > -1) or (lwrTitle.find('nisb') > -1):
  111. # newFlag = 1.0
  112. # else:
  113. # newFlag = 0.0
  114. # soldUnicde = currentRow[0].findAll('td')[3].findAll('span')
  115. # if len(soldUnicde)==0:
  116. # print "item #%d did not sell" % i
  117. # else:
  118. # soldPrice = currentRow[0].findAll('td')[4]
  119. # priceStr = soldPrice.text
  120. # priceStr = priceStr.replace('$','') #strips out $
  121. # priceStr = priceStr.replace(',','') #strips out ,
  122. # if len(soldPrice)>1:
  123. # priceStr = priceStr.replace('Free shipping', '') #strips out Free Shipping
  124. # print "%s\t%d\t%s" % (priceStr,newFlag,title)
  125. # fw.write("%d\t%d\t%d\t%f\t%s\n" % (yr,numPce,newFlag,origPrc,priceStr))
  126. # i += 1
  127. # currentRow = soup.findAll('table', r="%d" % i)
  128. # fw.close()

  129. from time import sleep
  130. import json
  131. import urllib2
  132. def searchForSet(retX, retY, setNum, yr, numPce, origPrc):
  133. sleep(10)
  134. myAPIstr = 'AIzaSyD2cR2KFyx12hXu6PFU-wrWot3NXvko8vY'
  135. searchURL = 'https://www.googleapis.com/shopping/search/v1/public/products?key=%s&country=US&q=lego+%d&alt=json' % (myAPIstr, setNum)
  136. pg = urllib2.urlopen(searchURL)
  137. retDict = json.loads(pg.read())
  138. for i in range(len(retDict['items'])):
  139. try:
  140. currItem = retDict['items'][i]
  141. if currItem['product']['condition'] =='new':
  142. newFlag = 1
  143. else: newFlag = 0
  144. listOfInv = currItem['product']['inventories']
  145. for item in listOfInv:
  146. sellingPrice = item['price']
  147. if sellingPrice > origPrc *0.5:
  148. print "%d\t%d\t%d\t%f\t%f" % (yr,numPce,newFlag,origPrc, sellingPrice)
  149. retX.append([yr, numPce, newFlag, origPrc])
  150. retY.append(sellingPrice)
  151. except: print'problem with item %d' % i

  152. def setDataCollect(retX, retY):
  153. searchForSet(retX, retY, 8288,2006, 800,49.99)
  154. searchForSet(retX, retY, 10030,2002, 3096,269.99)
  155. searchForSet(retX, retY, 10179,2007, 5195,499.99)
  156. searchForSet(retX, retY, 10181,2007, 3428,199.99)
  157. searchForSet(retX, retY, 10189,2008, 5922,299.99)
  158. searchForSet(retX, retY, 10196,2009, 3263,249.99)

  159. def crossValidation(xArr,yArr,numVal=10):
  160. m = len(yArr)
  161. indexList = range(m)
  162. errorMat = zeros((numVal,30))#create error mat 30columns numVal rows
  163. for i in range(numVal):
  164. trainX=[]; trainY=[]
  165. testX = []; testY = []
  166. random.shuffle(indexList)
  167. for j in range(m):#create training set based on first 90% of values in indexList
  168. if j < m*0.9:
  169. trainX.append(xArr[indexList[j]])
  170. trainY.append(yArr[indexList[j]])
  171. else:
  172. testX.append(xArr[indexList[j]])
  173. testY.append(yArr[indexList[j]])
  174. wMat = ridgeTest(trainX,trainY) #get 30 weight vectors from ridge
  175. for k in range(30):#loop over all of the ridge estimates
  176. matTestX = mat(testX); matTrainX=mat(trainX)
  177. meanTrain = mean(matTrainX,0)
  178. varTrain = var(matTrainX,0)
  179. matTestX = (matTestX-meanTrain)/varTrain #regularize test with training params
  180. yEst = matTestX * mat(wMat[k,:]).T + mean(trainY)#test ridge results and store
  181. errorMat[i,k]=rssError(yEst.T.A,array(testY))
  182. #print errorMat[i,k]
  183. meanErrors = mean(errorMat,0)#calc avg performance of the different ridge weight vectors
  184. minMean = float(min(meanErrors))
  185. bestWeights = wMat[nonzero(meanErrors==minMean)]
  186. #can unregularize to get model
  187. #when we regularized we wrote Xreg = (x-meanX)/var(x)
  188. #we can now write in terms of x not Xreg: x*w/var(x) - meanX/var(x) +meanY
  189. xMat = mat(xArr); yMat=mat(yArr).T
  190. meanX = mean(xMat,0); varX = var(xMat,0)
  191. unReg = bestWeights/varX
  192. print "the best model from Ridge Regression is:\n",unReg
  193. print "with constant term: ",-1*sum(multiply(meanX,unReg)) + mean(yMat)
[python] view plain copy
  1. from numpy import *  
  2.   
  3. def loadDataSet(fileName):      #general function to parse tab -delimited floats  
  4.     numFeat = len(open(fileName).readline().split('\t')) - 1 #get number of fields   
  5.     dataMat = []; labelMat = []  
  6.     fr = open(fileName)  
  7.     for line in fr.readlines():  
  8.         lineArr =[]  
  9.         curLine = line.strip().split('\t')  
  10.         for i in range(numFeat):  
  11.             lineArr.append(float(curLine[i]))  
  12.         dataMat.append(lineArr)  
  13.         labelMat.append(float(curLine[-1]))  
  14.     return dataMat,labelMat  
  15.   
  16. def standRegres(xArr,yArr):  
  17.     xMat = mat(xArr); yMat = mat(yArr).T  
  18.     xTx = xMat.T*xMat  
  19.     if linalg.det(xTx) == 0.0:  
  20.         print "This matrix is singular, cannot do inverse"  
  21.         return  
  22.     ws = xTx.I * (xMat.T*yMat)  
  23.     return ws  
  24.   
  25. def lwlr(testPoint,xArr,yArr,k=1.0):  
  26.     xMat = mat(xArr); yMat = mat(yArr).T  
  27.     m = shape(xMat)[0]  
  28.     weights = mat(eye((m)))  
  29.     for j in range(m):                      #next 2 lines create weights matrix  
  30.         diffMat = testPoint - xMat[j,:]     #  
  31.         weights[j,j] = exp(diffMat*diffMat.T/(-2.0*k**2))  
  32.     xTx = xMat.T * (weights * xMat)  
  33.     if linalg.det(xTx) == 0.0:  
  34.         print "This matrix is singular, cannot do inverse"  
  35.         return  
  36.     ws = xTx.I * (xMat.T * (weights * yMat))  
  37.     return testPoint * ws  
  38.   
  39. def lwlrTest(testArr,xArr,yArr,k=1.0):  #loops over all the data points and applies lwlr to each one  
  40.     m = shape(testArr)[0]  
  41.     yHat = zeros(m)  
  42.     for i in range(m):  
  43.         yHat[i] = lwlr(testArr[i],xArr,yArr,k)  
  44.     return yHat  
  45.   
  46. def lwlrTestPlot(xArr,yArr,k=1.0):  #same thing as lwlrTest except it sorts X first  
  47.     yHat = zeros(shape(yArr))       #easier for plotting  
  48.     xCopy = mat(xArr)  
  49.     xCopy.sort(0)  
  50.     for i in range(shape(xArr)[0]):  
  51.         yHat[i] = lwlr(xCopy[i],xArr,yArr,k)  
  52.     return yHat,xCopy  
  53.   
  54. def rssError(yArr,yHatArr): #yArr and yHatArr both need to be arrays  
  55.     return ((yArr-yHatArr)**2).sum()  
  56.   
  57. def ridgeRegres(xMat,yMat,lam=0.2):  
  58.     xTx = xMat.T*xMat  
  59.     denom = xTx + eye(shape(xMat)[1])*lam  
  60.     if linalg.det(denom) == 0.0:  
  61.         print "This matrix is singular, cannot do inverse"  
  62.         return  
  63.     ws = denom.I * (xMat.T*yMat)  
  64.     return ws  
  65.       
  66. def ridgeTest(xArr,yArr):  
  67.     xMat = mat(xArr); yMat=mat(yArr).T  
  68.     yMean = mean(yMat,0)  
  69.     yMat = yMat - yMean     #to eliminate X0 take mean off of Y  
  70.     #regularize X's  
  71.     xMeans = mean(xMat,0)   #calc mean then subtract it off  
  72.     xVar = var(xMat,0)      #calc variance of Xi then divide by it  
  73.     xMat = (xMat - xMeans)/xVar  
  74.     numTestPts = 30  
  75.     wMat = zeros((numTestPts,shape(xMat)[1]))  
  76.     for i in range(numTestPts):  
  77.         ws = ridgeRegres(xMat,yMat,exp(i-10))  
  78.         wMat[i,:]=ws.T  
  79.     return wMat  
  80.   
  81. def regularize(xMat):#regularize by columns  
  82.     inMat = xMat.copy()  
  83.     inMeans = mean(inMat,0)   #calc mean then subtract it off  
  84.     inVar = var(inMat,0)      #calc variance of Xi then divide by it  
  85.     inMat = (inMat - inMeans)/inVar  
  86.     return inMat  
  87.   
  88. def stageWise(xArr,yArr,eps=0.01,numIt=100):  
  89.     xMat = mat(xArr); yMat=mat(yArr).T  
  90.     yMean = mean(yMat,0)  
  91.     yMat = yMat - yMean     #can also regularize ys but will get smaller coef  
  92.     xMat = regularize(xMat)  
  93.     m,n=shape(xMat)  
  94.     #returnMat = zeros((numIt,n)) #testing code remove  
  95.     ws = zeros((n,1)); wsTest = ws.copy(); wsMax = ws.copy()  
  96.     for i in range(numIt):  
  97.         print ws.T  
  98.         lowestError = inf;   
  99.         for j in range(n):  
  100.             for sign in [-1,1]:  
  101.                 wsTest = ws.copy()  
  102.                 wsTest[j] += eps*sign  
  103.                 yTest = xMat*wsTest  
  104.                 rssE = rssError(yMat.A,yTest.A)  
  105.                 if rssE < lowestError:  
  106.                     lowestError = rssE  
  107.                     wsMax = wsTest  
  108.         ws = wsMax.copy()  
  109.         #returnMat[i,:]=ws.T  
  110.     #return returnMat  
  111.   
  112. #def scrapePage(inFile,outFile,yr,numPce,origPrc):  
  113. #    from BeautifulSoup import BeautifulSoup  
  114. #    fr = open(inFile); fw=open(outFile,'a') #a is append mode writing  
  115. #    soup = BeautifulSoup(fr.read())  
  116. #    i=1  
  117. #    currentRow = soup.findAll('table', r="%d" % i)  
  118. #    while(len(currentRow)!=0):  
  119. #        title = currentRow[0].findAll('a')[1].text  
  120. #        lwrTitle = title.lower()  
  121. #        if (lwrTitle.find('new') > -1) or (lwrTitle.find('nisb') > -1):  
  122. #            newFlag = 1.0  
  123. #        else:  
  124. #            newFlag = 0.0  
  125. #        soldUnicde = currentRow[0].findAll('td')[3].findAll('span')  
  126. #        if len(soldUnicde)==0:  
  127. #            print "item #%d did not sell" % i  
  128. #        else:  
  129. #            soldPrice = currentRow[0].findAll('td')[4]  
  130. #            priceStr = soldPrice.text  
  131. #            priceStr = priceStr.replace('$','') #strips out $  
  132. #            priceStr = priceStr.replace(',','') #strips out ,  
  133. #            if len(soldPrice)>1:  
  134. #                priceStr = priceStr.replace('Free shipping', '') #strips out Free Shipping  
  135. #            print "%s\t%d\t%s" % (priceStr,newFlag,title)  
  136. #            fw.write("%d\t%d\t%d\t%f\t%s\n" % (yr,numPce,newFlag,origPrc,priceStr))  
  137. #        i += 1  
  138. #        currentRow = soup.findAll('table', r="%d" % i)  
  139. #    fw.close()  
  140.       
  141. from time import sleep  
  142. import json  
  143. import urllib2  
  144. def searchForSet(retX, retY, setNum, yr, numPce, origPrc):  
  145.     sleep(10)  
  146.     myAPIstr = 'AIzaSyD2cR2KFyx12hXu6PFU-wrWot3NXvko8vY'  
  147.     searchURL = 'https://www.googleapis.com/shopping/search/v1/public/products?key=%s&country=US&q=lego+%d&alt=json' % (myAPIstr, setNum)  
  148.     pg = urllib2.urlopen(searchURL)  
  149.     retDict = json.loads(pg.read())  
  150.     for i in range(len(retDict['items'])):  
  151.         try:  
  152.             currItem = retDict['items'][i]  
  153.             if currItem['product']['condition'] == 'new':  
  154.                 newFlag = 1  
  155.             else: newFlag = 0  
  156.             listOfInv = currItem['product']['inventories']  
  157.             for item in listOfInv:  
  158.                 sellingPrice = item['price']  
  159.                 if  sellingPrice > origPrc * 0.5:  
  160.                     print "%d\t%d\t%d\t%f\t%f" % (yr,numPce,newFlag,origPrc, sellingPrice)  
  161.                     retX.append([yr, numPce, newFlag, origPrc])  
  162.                     retY.append(sellingPrice)  
  163.         exceptprint 'problem with item %d' % i  
  164.       
  165. def setDataCollect(retX, retY):  
  166.     searchForSet(retX, retY, 8288200680049.99)  
  167.     searchForSet(retX, retY, 1003020023096269.99)  
  168.     searchForSet(retX, retY, 1017920075195499.99)  
  169.     searchForSet(retX, retY, 1018120073428199.99)  
  170.     searchForSet(retX, retY, 1018920085922299.99)  
  171.     searchForSet(retX, retY, 1019620093263249.99)  
  172.       
  173. def crossValidation(xArr,yArr,numVal=10):  
  174.     m = len(yArr)                             
  175.     indexList = range(m)  
  176.     errorMat = zeros((numVal,30))#create error mat 30columns numVal rows  
  177.     for i in range(numVal):  
  178.         trainX=[]; trainY=[]  
  179.         testX = []; testY = []  
  180.         random.shuffle(indexList)  
  181.         for j in range(m):#create training set based on first 90% of values in indexList  
  182.             if j < m*0.9:   
  183.                 trainX.append(xArr[indexList[j]])  
  184.                 trainY.append(yArr[indexList[j]])  
  185.             else:  
  186.                 testX.append(xArr[indexList[j]])  
  187.                 testY.append(yArr[indexList[j]])  
  188.         wMat = ridgeTest(trainX,trainY)    #get 30 weight vectors from ridge  
  189.         for k in range(30):#loop over all of the ridge estimates  
  190.             matTestX = mat(testX); matTrainX=mat(trainX)  
  191.             meanTrain = mean(matTrainX,0)  
  192.             varTrain = var(matTrainX,0)  
  193.             matTestX = (matTestX-meanTrain)/varTrain #regularize test with training params  
  194.             yEst = matTestX * mat(wMat[k,:]).T + mean(trainY)#test ridge results and store  
  195.             errorMat[i,k]=rssError(yEst.T.A,array(testY))  
  196.             #print errorMat[i,k]  
  197.     meanErrors = mean(errorMat,0)#calc avg performance of the different ridge weight vectors  
  198.     minMean = float(min(meanErrors))  
  199.     bestWeights = wMat[nonzero(meanErrors==minMean)]  
  200.     #can unregularize to get model  
  201.     #when we regularized we wrote Xreg = (x-meanX)/var(x)  
  202.     #we can now write in terms of x not Xreg:  x*w/var(x) - meanX/var(x) +meanY  
  203.     xMat = mat(xArr); yMat=mat(yArr).T  
  204.     meanX = mean(xMat,0); varX = var(xMat,0)  
  205.     unReg = bestWeights/varX  
  206.     print "the best model from Ridge Regression is:\n",unReg  
  207.     print "with constant term: ",-1*sum(multiply(meanX,unReg)) + mean(yMat)  


参考文献:

[1] machine learning in action.Peter Harrington

[2]Linear Algebra and Its Applications_4ed.Gilbert_Strang


转载请注明来源:http://blog.csdn.net/cuoqu/article/details/9387305

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值