Regression Algorithms
Linear Regression
Definition of linear regression
Linear regression models the relationship between one or more independent variables and a dependent variable; its defining property is that the prediction is a linear combination of the regression coefficients. Depending on the number of independent variables, it is divided into simple (one-variable) and multiple linear regression. The model is written as:
h(w) = w0 + w1*x1 + w2*x2 + ... + wn*xn = W^T * X
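As a minimal illustration of the linear-combination form, the sketch below evaluates h(w) for one sample (the weights and feature values are made-up, purely for illustration):

```python
import numpy as np

# Hypothetical model with 3 features: w0 is the intercept, w1..w3 the coefficients.
w = np.array([1.0, 2.0, -0.5, 0.3])   # [w0, w1, w2, w3] (illustrative values)
x = np.array([1.0, 4.0, 2.0, 10.0])   # [1, x1, x2, x3]; the leading 1 pairs with w0

# h(w) = w0 + w1*x1 + w2*x2 + w3*x3, i.e. a dot product of weights and inputs
y_hat = w @ x
print(y_hat)  # 11.0
```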
Measuring the error of linear regression
The loss function is the sum of squared errors, which the method of least squares minimizes:
J(w) = (h_w(x1) - y1)^2 + (h_w(x2) - y2)^2 + ... + (h_w(xm) - ym)^2
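A minimal NumPy sketch of this loss, using made-up predictions and targets for four samples:

```python
import numpy as np

# Hypothetical predictions h_w(xi) and true targets yi for 4 samples.
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])

# Sum of squared errors (the least-squares loss), and its mean (MSE).
sse = np.sum((y_pred - y_true) ** 2)
mse = sse / len(y_true)
print(sse, mse)  # 1.5 0.375
```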
How linear regression reduces the error
Core idea: find the weight vector W that minimizes the loss.
Method 1: the normal equation
- Formula: W = (X^T X)^(-1) X^T y
- Note: the data should be standardized first.
- Drawback: when there are many features, solving is too slow, since it requires inverting X^T X.
- Code example (predicting Boston housing prices)
from sklearn.datasets import load_boston
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

def mylinear():
    """
    Predict Boston housing prices with LinearRegression.
    Note: load_boston was removed in scikit-learn 1.2, so this requires scikit-learn < 1.2.
    :return: None
    """
    # Load the data
    lb = load_boston()
    # Split into training and test sets
    x_train, x_test, y_train, y_test = train_test_split(lb.data, lb.target, test_size=0.25)
    # Standardize (features and targets need separate scaler instances because their shapes differ)
    std_x = StandardScaler()
    x_train = std_x.fit_transform(x_train)
    x_test = std_x.transform(x_test)
    std_y = StandardScaler()
    y_train = std_y.fit_transform(y_train.reshape(-1, 1))
    y_test = std_y.transform(y_test.reshape(-1, 1))
    # Fit the estimator and predict prices
    lr = LinearRegression()
    lr.fit(x_train, y_train)
    print(lr.coef_)
    predict_std = lr.predict(x_test)
    # Undo the target scaling to get prices back in original units
    y_predict = std_y.inverse_transform(predict_std.reshape(-1, 1))
    print("Linear regression predicted prices:", y_predict)
    print("Linear regression mean squared error:",
          mean_squared_error(std_y.inverse_transform(y_test), y_predict))
    return None

if __name__ == '__main__':
    mylinear()
Output:
[[-0.07189689  0.11056253  0.00852308  0.06510871 -0.17156743  0.3417623
  -0.0550457  -0.29720397  0.22697546 -0.21461555 -0.21549207  0.12524146
  -0.323476  ]]
Linear regression predicted prices: [[16.67298491]
 [21.51430668]
 [15.63161012]
 [41.67971428]
 [22.12070811]
 [29.74143583]
 ... (remaining predictions truncated)]
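The normal-equation formula W = (X^T X)^(-1) X^T y can also be applied directly with NumPy. The sketch below fits a one-feature model to synthetic data (illustrative values, not the Boston dataset) and recovers coefficients close to the true intercept 4 and slope 3:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: y = 4 + 3*x + noise (illustrative, not the Boston dataset).
m = 100
x = rng.uniform(0, 2, size=(m, 1))
y = 4 + 3 * x[:, 0] + rng.normal(0, 0.1, size=m)

# Design matrix with a leading column of ones for the intercept term.
X = np.hstack([np.ones((m, 1)), x])

# Normal equation: W = (X^T X)^(-1) X^T y
W = np.linalg.inv(X.T @ X) @ X.T @ y
print(W)  # close to [4, 3]
```

In practice np.linalg.lstsq (or scikit-learn's LinearRegression, as above) is preferred over explicit inversion, since X^T X can be ill-conditioned.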