流程
回归问题的流程可以使用如下Demo进行描述:
一般使用最小二乘法作为损失函数:
回归的最终目标是使损失函数最小化,从而提高预测的准确率。一般使用梯度下降法来求损失函数的最小值。
使用均方误差来进行回归性能的评估:
代码
from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
def myLinear():
    """Predict Boston house prices with SGD-based linear regression.

    Loads the Boston housing dataset, splits it into train/test sets,
    standardizes both features and target, fits an SGDRegressor, and
    reports the mean squared error on the original (unscaled) price scale.

    :return: None
    """
    # Load the dataset (features + target values).
    # NOTE(review): load_boston was removed in scikit-learn 1.2; on modern
    # versions switch to e.g. fetch_california_housing — confirm the
    # installed sklearn version.
    lb = load_boston()

    # Split into training and test sets (25% held out for testing).
    x_train, x_test, y_train, y_test = train_test_split(
        lb.data, lb.target, test_size=0.25)

    # Standardize the features. Fit the scaler on the TRAINING data only,
    # then apply the same transform to the test data — calling
    # fit_transform on the test set (as the original did) leaks test-set
    # statistics and puts train and test on inconsistent scales.
    std_x = StandardScaler()
    x_train = std_x.fit_transform(x_train)
    x_test = std_x.transform(x_test)

    # Standardize the target the same way (fit on train, transform test).
    # StandardScaler expects a 2-D array, hence reshape(-1, 1).
    std_y = StandardScaler()
    y_train = std_y.fit_transform(y_train.reshape(-1, 1))
    y_test = std_y.transform(y_test.reshape(-1, 1))

    # Estimator: linear regression fitted by stochastic gradient descent.
    sgd = SGDRegressor()
    # ravel() supplies the 1-D target sklearn expects (a (n, 1) column
    # triggers a DataConversionWarning / shape error).
    sgd.fit(x_train, y_train.ravel())
    print(sgd.coef_)  # learned weight vector

    # Predict on the test set and map predictions back to the original
    # price scale. predict() returns a 1-D array; inverse_transform
    # requires a 2-D column, hence reshape(-1, 1).
    y_sgd_predict = std_y.inverse_transform(
        sgd.predict(x_test).reshape(-1, 1))
    print('测试集每个房子的预测价格', y_sgd_predict)
    print('均方误差:', mean_squared_error(
        std_y.inverse_transform(y_test), y_sgd_predict))
    return None
# Script entry point: run the regression demo only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    myLinear()
输出如下:
参数w:[-0.06161057 0.06221613 -0.04479416 0.09231134 -0.11200841 0.30607084
0.00477479 -0.22025804 0.08570166 -0.03266436 -0.19150812 0.07555292
-0.38888804]
测试集每个房子的预测价格 [30.7357134 25.51974831 21.00425056 17.53829466 14.60163444 18.20009286
18.77399008 16.24578979 18.05538513 18.24586873 23.68445002 21.55312467
14.2357615 30.05169882 21.29156398 25.71810244 21.9950333 19.38126953
17.65906691 28.21374685 30.64545386 36.27230954 28.0784913 25.92380663
22.10676399 25.91674478 25.7805641 24.84206141 14.95562282 15.69368713
10.90358757 19.93573849 23.09866024 31.89905785 40.91873019 20.85018915
24.97985589 28.1411417 21.62299662 34.6012922 27.14617434 18.2070012
26.57700721 33.55133097 31.68044403 31.2399297 26.53571731 19.93378672
5.19917507 25.81266873 11.16420083 25.31760464 16.7161709 9.83216067
18.67996126 20.71075185 13.45801083 18.79893355 31.06572454 20.8723132
29.63527381 20.71092781 14.82212843 19.56480437 25.08577622 23.654464
23.34284048 18.53246061 23.78033433 34.81796526 26.63245778 27.12908949
24.16441436 19.76520137 24.71149459 35.51630278 24.29390015 37.90608215
12.71859071 17.66553579 24.97056766 33.07690784 13.57167345 19.64617293
33.67939841 22.40476151 27.52738195 21.36951862 4.46465718 18.47605145
23.00769744 33.29154626 17.05723082 23.53042879 6.28904547 17.37455041
14.19539973 20.44018977 20.09930415 23.13221996 29.34438395 25.43009535
24.02860846 18.03260796 13.85634042 35.67607996 21.9293058 20.67176952
28.31205854 35.22251859 27.54608516 10.24048281 23.6108069 30.23146429
23.16196823 14.64628754 27.82333924 25.92903784 14.2912243 20.91615769
22.820459 24.95822803 26.3137871 18.18402887 38.31359149 18.67107512
18.74023516]
均方误差: 21.328362012748162