参考的文章:
http://blog.csdn.net/lulei1217/article/details/49385531
http://blog.csdn.net/LULEI1217/article/details/49386295
我将该作者上面两篇文章中的代码修改后,变成下面的样子。数据集在附件中。
#coding:utf-8
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
'''
Created on 2016年12月9日
多维向量的线性回归
href:http://blog.csdn.net/LULEI1217/article/details/49386295
'''
def get_data(file_name,feature_cols):
#负责读取 csv文件.文件的第一行是列名。本例中的数据集在附件中
data = pd.read_csv(file_name)
X_parameter = data[feature_cols]
Y_parameter = data['Sales']
return X_parameter,Y_parameter
#输入参数为X轴的数据集,Y轴的数据集,假设的X的值
#输出是X的预测Y
def linear_model_main(X_parameters,Y_parameters,predict_value):
#调用api进行线性回归
regr = linear_model.LinearRegression()
regr.fit(X_parameters, Y_parameters)
#获得预测输出
predict_outcome = regr.predict(predict_value)
predictions = {}
predictions['intercept'] = regr.intercept_
predictions['coefficient'] = regr.coef_
predictions['predicted_value'] = predict_outcome
return predictions
#绘制效果比较图
def show_comparison(Y_parameters,Y_pred):
plt.figure()
plt.plot(range(len(Y_pred)),Y_pred,'b',label="predict")
plt.plot(range(len(Y_pred)),Y_parameters,'r',label="reality")
plt.legend(loc="upper right") #显示图中的标签
plt.xlabel("the sequence number of sales")
plt.ylabel('value of sales')
plt.show()
#评价算法效果的指标RMSE(Root Mean Squared Error 均方根误差)
def get_rmse(Y_parameters,Y_pred):
sum_mean=0
for i in range(len(Y_pred)):
sum_mean+=(Y_pred[i]-Y_parameters.values[i])**2
sum_error=np.sqrt(sum_mean/len(Y_pred))
return sum_error
if __name__ == '__main__':
print "Job Begins"
file_name="./Advertising.csv"
# 可以按照需要使用需要的列,增删列名即可
feature_cols = ['TV', 'Radio', 'Newspaper']
X,Y = get_data(file_name,feature_cols)
regr = linear_model.LinearRegression()
#数据的前195条用于训练,最后5条用于测试
model = regr.fit(X.head(195), Y.head(195))
#自己创造数据集的话,用下面的形式
# X_test = pd.DataFrame({'TV':[1],'Radio':[1],'Newspaper':[1]})
X_test = X.tail(5)
Y_real = Y.tail(5)
Y_pred = regr.predict(X_test)
# show_comparison(Y_real,Y_pred)
#通过改变features后的均方差大小,可以发现只使用TV和Radio这两个feature效果最好
print get_rmse(Y_real, Y_pred)
线性回归要求数据集必须是线性相关的,也是最基础的算法。比如要预测股票价格,就有点说不通了。算法的验证是一个比较复杂的过程,例子里就只做了一次。正常情况应该使用“十折验证法”。方法就是将数据分成十份,每次用一份当验证集合,其他九份当训练集合,最后看综合表现结果。
顺便再提两个名词,回归和过拟合。
回归的意思是将参数代入目标函数,并获得结果,也就是通常讲的预测。
过拟合的意思是训练之后的模型只对于训练数据有较好的表现,对于测试的数据表现不佳。