8.1 用线性回归找到最佳拟合直线
1.
以下代码用到的注释
1.readline()每次只读取一行,只需读取一行计算特征值节省内存
2. readlines()一次读取整个文件,自动将文件内容分析成一个行的列表3.#strip():返回移除字符串头尾指定的字符生成的新字符串
4.#split()通过指定分隔符对字符串进行切片,如果参数num 有指定值,则仅分隔 num 个子字符串
5.range() 函数可创建一个整数列表,一般用在 for 循环中。
>>>range(10) # 从 0 开始到 10
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> range(1, 11) # 从 1 开始到 11
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> range(0, 30, 5) # 步长为 5
[0, 5, 10, 15, 20, 25]
6.append() 方法用于在列表末尾添加新的对象
extend() 函数用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)
append(), extend()都是list的函数,
所以要对矩阵进行类似操作的时候,需要现将矩阵通过上述方法转化成list,再进行操作。
7.numpy中的mat使list变换成了矩阵
from numpy import *
#解析文件中的数据为适合机器处理的形式
def loadDataSet(fileName):
numFeat=len(open(fileName).readline().split('\t'))-1
# .readline()每次只读取一行,只需读取一行计算特征值洁身内存
dataMat=[]; labelMat=[]
fr=open(fileName)
for line in fr.readlines():
lineArr=[]
curLine=line.strip().split('\t')
for i in range(numFeat):
lineArr.append(float(curLine[i]))
dataMat.append(lineArr)
labelMat.append(float(curLine[-1]))
return dataMat,labelMat
#标准线性回归算法:ws=(X.T*X).I*(X.T*Y)
def standRegres(xArr,yArr):
#将列表形式的数据转为numpy矩阵形式
xMat=mat(xArr);yMat=mat(yArr).T
#求矩阵的内积
xTx=xMat.T*xMat
#numpy线性代数库linalg
#调用linalg.det()计算矩阵行列式
#计算矩阵行列式是否为0
if linalg.det(xTx)==0.0:
print('This matrix is singular,cannot do inverse')
return
#如果可逆,根据公式计算回归系数
ws=xTx.I*(xMat.T*yMat)
#可以用yHat=xMat*ws计算实际值y的预测值
#返回归系数
return ws
2.在命令行写入如下代码
>>>import regression
>>>from numpy import *
>>>xArr,yArr = >>>regression.loadDataSet('ex.txt')
>>>ws =regression.standRegres(xarr,yarr)
>>>ws
yHat是预测出的值
>>>xMat = mat(xarr)
>>>yMat = mat(yarr)
>>>yHat = xMat*ws
>>>import matplotlib.pyplot as plt
>>>fig = plt.figure()
>>>ax = fig.add_subplot(111)
>>>ax.scatter(xMat[:,1].flatten().A[0],yMat.T[:,0].flatten().A[0])
Out[44]: <matplotlib.collections.PathCollection at 0x1fd29ef7048>
>>>xCopy = xMat.copy()
>>>xCopy.sort(0)
>>>yHat= xCopy*ws
>>>ax.plot(xCopy[:,1],yHat)
Out[48]: [<matplotlib.lines.Line2D at 0x1fd29f1b320>]
>>>plt.show()
涉及到的一点matplotlib的知识
1.plt.figure()指的是画出一个画布
subplot(numRows, numCols,plotNum)
图表的整个绘图区域被分成numRows行和numCols列,plotNum参数指定创建的Axes对象所在的区域
plot函数用于做xy的关系图(直线图)
2.scatter函数用于画散点图
有关scatter的详细介绍:http://blog.csdn.net/anneqiqi/article/details/64125186
3.a是个矩阵或者数组,a.flatten()就是把a降到一维,默认是按横的方向降
>>> a = np.array([[1,2], [3,4]])
>>> a.flatten()
array([1, 2, 3, 4])
>>> a.flatten('F') #按竖的方向降
array([1, 3, 2, 4])