多元线性回归
多元线性回归与一元线性回归类似,只是特征值由一个变为了两个及以上。
表达式:
h
θ
(
x
i
)
=
θ
0
+
θ
1
x
+
θ
2
x
2
+
.
.
.
+
θ
n
x
n
h_\theta(x_i)=\theta_0+\theta_1x+\theta_2x_2+...+\theta_nx_n
hθ(xi)=θ0+θ1x+θ2x2+...+θnxn
因此可用向量写成:
h
θ
(
x
i
)
=
θ
i
X
i
T
h_\theta(x_i)=\theta_iX_i^T
hθ(xi)=θiXiT
其中
X
T
=
(
X
0
,
X
1
,
X
2
.
.
.
.
,
X
n
)
X^T=(X_0,X_1,X_2....,X_n)
XT=(X0,X1,X2....,Xn)其中
X
0
X_0
X0恒为1。
而代价函数仍为:
(
真
实
值
−
预
测
值
)
2
(真实值-预测值)^2
(真实值−预测值)2 的平均·。
除了多元线性回归,还有多项式回归。多项式回归是因为用直线拟合不够准确,因此要用平滑的曲线拟合,如下图为多项式回归的一种情况:
多项式回归的一般式子可写成:
Y
i
=
β
0
+
β
1
X
i
+
β
2
X
2
+
.
.
.
+
β
k
X
i
k
Y_i=\beta_0+\beta_1X_i+\beta_2X_2+...+\beta_kX_i^k
Yi=β0+β1Xi+β2X2+...+βkXik
其中当k的值越大时,拟合的效果越好,曲线越平滑,但可能出现过拟合的情况。
标准方程法
除了梯度下降法,还有标准方程法也可用于求解参数。一般当参数较少时用标准方程法较为合适,其复杂度为O(
n
3
n^3
n3),其中n为k的大小,即特征量的个数。
已知代价函数为:
∑
i
=
1
m
(
h
w
(
x
i
)
−
y
i
)
2
=
(
y
−
X
w
)
T
(
y
−
X
w
)
\sum_{i=1}^{m}(h_w(x^i)-y^i)^2=(y-Xw)^T(y-Xw)
∑i=1m(hw(xi)−yi)2=(y−Xw)T(y−Xw)
例如下列一组数据:
x
=
[
x
0
x
1
x
2
x
3
x
4
1
2104
5
1
45
1
1416
3
2
40
1
1536
3
2
30
1
852
2
1
36
]
x= \left[ \begin{matrix} x_0&x_1&x_2&x_3&x_4\\ 1 & 2104 &5&1&45 \\ 1 & 1416 &3&2&40\\ 1 & 1536 &3&2&30 \\ 1 & 852 &2&1&36 \\ \end{matrix} \right]
x=⎣⎢⎢⎢⎢⎡x01111x1210414161536852x25332x31221x445403036⎦⎥⎥⎥⎥⎤
w
=
[
w
0
w
1
w
2
w
3
w
4
]
w= \left[ \begin{matrix} w_0\\ w_1\\ w_2\\ w_3\\ w_4\\ \end{matrix} \right]
w=⎣⎢⎢⎢⎢⎡w0w1w2w3w4⎦⎥⎥⎥⎥⎤
y
=
[
460
232
315
178
]
y= \left[ \begin{matrix} 460\\ 232\\ 315\\ 178\\ \end{matrix} \right]
y=⎣⎢⎢⎡460232315178⎦⎥⎥⎤
其中w为参数向量,对代价函数的w进行求偏导,即
∂
∂
w
[
(
y
−
X
w
)
T
(
y
−
X
w
)
]
\frac{\partial}{\partial w}[(y-Xw)^T(y-Xw)]
∂w∂[(y−Xw)T(y−Xw)]
根据矩阵求导法则可求得 w=
(
X
T
X
)
−
1
X
T
y
(X^TX)^{-1}X^Ty
(XTX)−1XTy
实例代码如下:
import numpy as np
from numpy import genfromtxt
import matplotlib.pyplot as plt
# 载入数据
data = np.genfromtxt("Salary_Data.csv",delimiter=",")
x_data = data[1:,0,np.newaxis]
y_data = data[1:,1,np.newaxis]
plt.scatter(x_data,y_data)
plt.show()
#为X矩阵 添加偏置项,即添加x0=1
X_data = np.concatenate((np.ones((30,1)),x_data),axis=1)
print(X_data)
#标准方程法求解回归参数
def weights(xArr, yArr):
xMat = np.mat(xArr)
yMat = np.mat(yArr)
xTx = xMat.T*xMat # 矩阵乘法
# 计算矩阵的值,如果值为0,说明矩阵不可逆
if np.linalg.det(xTx) == 0.0:
print("矩阵不可逆!")
return
ws=xTx.I*xMat.T*yMat
return ws
x_test = np.array([[0],[12]])
y_test = ws[0]+x_test*ws[1]
plt.plot(x_data,y_data,'b.')
plt.plot(x_test,y_test,'r')
plt.show()
其中salary_data.csv 为参考的工资数据。
最终结果如下: