写一个简单的线性回归模型。
看了一下斯坦福 Andrew Ng的机器学习。陆续的会将其中一些算法的的程序写一写 。谈及机器学习,其中两个非常重要的问题,即回归问题和分类问题。在计算机视觉中都有很多的应用。
主题:线性回归模型。
以上是理论部分 ,下面贴出程序(python):
import numpy as np
import matplotlib.pyplot as plt
class LINREG(object):
def __init__(self,data,label,alpha):
self.x =np.vstack((np.ones((1,data.size)),data) ) #construct train example
self.y =label
self.alpha =alpha
self.theta=10*np.random.random((1,2))
self.second_term = np.ones((1,2))
def hypothesis(self,x):
h=self.theta.dot(self.x) #row =1 col not change
return h
def grident_descent(self,h,y):
self.second_term = np.sum( (h-y)* self.x,axis = 1)
# print self.second_term.shape
self.theta = self.theta - self.alpha*self.second_term
def printf(self):
print self.theta
def learn(self):
count = 0
while count<1000:
count=count+1
h=self.hypothesis(self.x)
self.grident_descent(h,self.y)
if __name__ =='__main__':
trainData= np.array([[1,2,3,4]])
trainLabel= np.array([[2,5,6,7]])
line_regression = LINREG(trainData,trainLabel,0.01)# data label learning rate
line_regression.learn()
line_regression.printf() #print theta after learning
testdata = np.arange(0,9,1)
x =np.vstack((np.ones((1,testdata.size)),testdata) ) # construct test data
y=line_regression.theta.dot(x) #the result of test data
# draw the result
plt.plot(testdata,y[0])
plt.plot(trainData[0],trainLabel[0],'r^')
plt.axis([0,10,0,10])
plt.xlabel('x')
plt.ylabel('y')
plt.show()
试验结果图如下:
reference:
斯坦福 机器学习 课程