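Machine learning algorithms can be broadly divided into supervised learning and unsupervised learning. Regression and classification are the representative tasks of supervised learning: classification predicts discrete targets, while regression predicts continuous ones. For numerical samples, the goal of regression is to build a regression equation that predicts the target value. This post shows how to build a linear regression model with the sklearn (scikit-learn) module and use it to predict target values.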
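For the linear case, the "regression equation" can be sketched as follows. The prediction is a linear function of the p features. sklearn's LinearRegression picks the coefficients by ordinary least squares, which means it minimizes the residual sum of squares over the n training samples. In sklearn's naming, w_0 is stored as `intercept_` and w_1, ..., w_p as `coef_`. The symbols below are this sketch's own notation.

```latex
\hat{y} = w_0 + w_1 x_1 + w_2 x_2 + \cdots + w_p x_p

\min_{w_0, \dots, w_p} \; \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2
```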
(1) Data preparation and preprocessing:
```python
from sklearn import datasets
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

diabets = datasets.load_diabetes()
# data has 10 feature columns; here we take the third one (index 2, BMI)
x = diabets.data[:, np.newaxis, 2]
y = diabets.target
x_train, x_test, y_train, y_test = train_test_split(x, y)
```
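np.newaxis is not a function; it is an alias for None. Used as an index, it adds an axis to the data, turning the selected column into an n×1 matrix (n rows, 1 column). This 2-D shape is what sklearn estimators expect for X.

train_test_split() splits the data into a training set and a test set. The proportions can be set with the test_size and train_size parameters. By default, the test set is 25% of the data.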
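A quick way to check what np.newaxis and train_test_split actually produce is to print the shapes. This is a minimal sketch on the same diabetes data (442 samples). The random_state argument is an addition here, used only to make the split reproducible. The reshape and slice variants are shown as equivalent alternatives, not as the original author's code.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

diabets = datasets.load_diabetes()

col = diabets.data[:, 2]               # 1-D array
x = diabets.data[:, np.newaxis, 2]     # 2-D column
print(col.shape, x.shape)              # (442,) (442, 1)
print(np.newaxis is None)              # True: np.newaxis is an alias for None

# two equivalent ways to get the same (n, 1) column
assert np.array_equal(x, col.reshape(-1, 1))
assert np.array_equal(x, diabets.data[:, 2:3])

# default split: 25% of the samples go to the test set (rounded up)
x_train, x_test, y_train, y_test = train_test_split(x, diabets.target, random_state=0)
print(len(x_train), len(x_test))       # 331 111

# an explicit 80/20 split via test_size
x_tr, x_te, y_tr, y_te = train_test_split(x, diabets.target, test_size=0.2, random_state=0)
print(len(x_tr), len(x_te))            # 353 89
```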
(2) Modeling and prediction:
```python
regression = LinearRegression()
regression.fit(x_train, y_train)
diab_y_pred = regression.predict(x_test)
```
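LinearRegression is an estimator class. Like other sklearn predictors, it provides fit and predict methods. It also has a score method for evaluating the model; for regressors, score returns the R² of the predictions.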
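To see what fit, predict and score do for LinearRegression, the sketch below inspects the learned attributes. It checks that predict simply evaluates the fitted linear equation. It also checks that score matches sklearn.metrics.r2_score on the same predictions. The fixed random_state and the variable name manual are illustrative additions.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

diabets = datasets.load_diabetes()
x = diabets.data[:, np.newaxis, 2]
y = diabets.target
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)

regression = LinearRegression()
regression.fit(x_train, y_train)          # learns intercept_ and coef_
diab_y_pred = regression.predict(x_test)  # applies the learned equation

# fitted equation: y_hat = intercept_ + coef_[0] * x
print(regression.intercept_, regression.coef_)

# predict() is intercept_ + x @ coef_
manual = regression.intercept_ + x_test @ regression.coef_
print(np.allclose(manual, diab_y_pred))   # True

# for regressors, score() is the R² of predict() against the true targets
print(np.isclose(regression.score(x_test, y_test), r2_score(y_test, diab_y_pred)))  # True
```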
(3) Model evaluation:
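```python
regression.score(x_test, y_test)
# score returns the coefficient of determination R².
# For a simple linear regression (one feature, with intercept), R² on the
# training data equals the square of the Pearson product-moment correlation
# coefficient. R² is at most 1 and can even be negative on a test set.
# The output here was 0.2916 (the value depends on the random split).

import matplotlib.pyplot as plt

# plot the first 200 samples together with the fitted line
x_valid = diabets.data[:200, 2]
xx_valid = x_valid[:, np.newaxis]
y_valid = diabets.target[:200]
plt.scatter(x_valid, y_valid)
plt.plot(xx_valid, regression.predict(xx_valid), 'r-')
plt.show()
```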
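The claims about R² can be checked numerically. This sketch computes R² by hand as 1 - SS_res / SS_tot and compares it with score. It then confirms that, on the training data of a one-feature fit, R² equals the squared Pearson correlation. Note that it is r², not r itself. The helper r_squared is a hypothetical name introduced only for this illustration.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

diabets = datasets.load_diabetes()
x = diabets.data[:, np.newaxis, 2]
y = diabets.target
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)
regression = LinearRegression().fit(x_train, y_train)

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot (hypothetical helper)."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot

r2_test = r_squared(y_test, regression.predict(x_test))
print(np.isclose(r2_test, regression.score(x_test, y_test)))  # True

# on the training data, R² of a one-feature least-squares fit equals Pearson r squared
r = np.corrcoef(x_train[:, 0], y_train)[0, 1]
r2_train = regression.score(x_train, y_train)
print(np.isclose(r2_train, r ** 2))  # True
```

The training-set identity only holds for ordinary least squares with an intercept. On held-out data, R² has no lower bound: a model worse than predicting the mean gets a negative score.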
(4) Polynomial regression:
There is still only one explanatory variable, but higher powers of it are added as extra terms; here the degree is 3.
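```python
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt

cubic_featurizer = PolynomialFeatures(degree=3)
x_train_cubic = cubic_featurizer.fit_transform(x_train)
x_test_cubic = cubic_featurizer.transform(x_test)

# 100 evenly spaced points used to draw the fitted curves
xx = np.linspace(-0.1, 0.1, 100)
xx_ = xx[:, np.newaxis]
xx_valid_cubic = cubic_featurizer.transform(xx_)

reg_cubic = LinearRegression()
reg_cubic.fit(x_train_cubic, y_train)
reg_cubic.score(x_test_cubic, y_test)  # R² of the cubic model on the test set

# red dashed: simple linear model; blue solid: cubic model
plt.plot(xx, regression.predict(xx_), 'r--')
plt.plot(xx, reg_cubic.predict(xx_valid_cubic), 'b-')
plt.show()
```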
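The PolynomialFeatures transformer expands a single explanatory variable x into the terms 1, x, x², x³ (for degree=3). An ordinary linear model fitted on these terms is therefore a cubic polynomial in x.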
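The sketch below shows what PolynomialFeatures produces on a tiny input. It then rebuilds the cubic model as a single sklearn Pipeline via make_pipeline. The pipeline is a swapped-in convenience, not the original author's code. It applies the same featurizer to training and test data automatically, so the train/test transform calls cannot get out of step. The fixed random_state is an illustrative addition.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# what the transformer produces for x = 2 and x = 3
cubic_featurizer = PolynomialFeatures(degree=3)
demo = cubic_featurizer.fit_transform(np.array([[2.0], [3.0]]))
print(demo)  # rows: [1, 2, 4, 8] and [1, 3, 9, 27]

diabets = datasets.load_diabetes()
x = diabets.data[:, np.newaxis, 2]
y = diabets.target
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)

# cubic model as one estimator: featurize, then fit a linear model
cubic_model = make_pipeline(PolynomialFeatures(degree=3), LinearRegression())
cubic_model.fit(x_train, y_train)

# the equivalent manual two-step version
featurizer = PolynomialFeatures(degree=3)
reg_cubic = LinearRegression().fit(featurizer.fit_transform(x_train), y_train)
manual_pred = reg_cubic.predict(featurizer.transform(x_test))

print(np.allclose(cubic_model.predict(x_test), manual_pred))  # True
print(cubic_model.score(x_test, y_test))  # test-set R² of the cubic model
```

The constant column of ones is redundant with LinearRegression's own intercept, but it does not change the predictions. Passing include_bias=False to PolynomialFeatures drops it.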
(5) Multiple linear regression:
Here we assume the target value depends on all 10 features, so x takes every column of the data.
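```python
x = diabets.data  # all 10 feature columns
y = diabets.target
x_train, x_test, y_train, y_test = train_test_split(x, y)
regression = LinearRegression()
regression.fit(x_train, y_train)
y_pred = regression.predict(x_test)
regression.score(x_test, y_test)
# score printed 0.4558 here, a better fit than the single-variable model
# (the exact value depends on the random split)

import matplotlib.pyplot as plt

# blue: true targets of the test samples; red: predictions
plt.plot(range(len(y_test)), y_test, color="blue")
plt.plot(range(len(y_test)), y_pred, color="red")
plt.show()
```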
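A single random split makes both scores noisy. This sketch uses 5-fold cross-validation (cross_val_score with a shuffled KFold) to compare the all-feature model with the BMI-only model, averaging over folds. That is a substitute evaluation method, not the one used above. It also prints the fitted coefficient for each named feature. The coefficients come from a fit on the full dataset, purely for illustration.

```python
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

diabets = datasets.load_diabetes()
x = diabets.data      # shape (442, 10)
y = diabets.target

# one coefficient per feature in the fitted equation (fit on all data, illustration only)
regression = LinearRegression().fit(x, y)
for name, coef in zip(diabets.feature_names, regression.coef_):
    print(f"{name:>4}: {coef:9.2f}")

# 5-fold cross-validated R²; default scoring for a regressor is its score(), i.e. R²
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores_all = cross_val_score(LinearRegression(), x, y, cv=cv)
scores_bmi = cross_val_score(LinearRegression(), x[:, [2]], y, cv=cv)
print("all 10 features:", scores_all.mean())
print("bmi only:       ", scores_bmi.mean())
```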