线性回归几乎是所有机器学习的入门课程,但是由于符号定义表达方式不同,造成了很多人在入门时期感觉跟多向量非常矛盾。比如所行向量还是列向量, x i \textbf{x}_i xi以及 x j \textbf{x}_j xj究竟是行还是列等等,本篇将先介绍向量以及列表相关的例子,然后再介绍线性回归的内容。
入门解惑
对于大多数教程而言,一份统计表的形式往往如下所示:
示例1:
示例1为默认格式,也是大多数博客或文章采用的格式。
属性1 | 属性2 | 属性3 | 属性4 | … | |
---|---|---|---|---|---|
元组1 | |||||
元组2 | |||||
元组3 | |||||
元组i | … | … | … | … | … |
采用向量描述:
属性1 | 属性2 | 属性3 | 属性4 | … | |
---|---|---|---|---|---|
元组i | x i 1 x_{i1} xi1 | x i 2 x_{i2} xi2 | x i 3 x_{i3} xi3 | x i 4 x_{i4} xi4 | … |
x
i
\textbf{x}_i
xi代表行,但是并不能单纯地认为代表行的就是行向量,实际上大多数书籍或博客中默认都是列向量,如果有定义最好看清楚定义。
形如
x
i
=
(
x
i
1
,
x
i
2
,
x
i
3
,
x
i
4
,
.
.
.
x
i
n
)
T
\textbf{x}_i=(x_{i1},x_{i2},x_{i3},x_{i4},...x_{in})^T
xi=(xi1,xi2,xi3,xi4,...xin)T是列向量
形如
x
i
=
(
x
i
1
,
x
i
2
,
x
i
3
,
x
i
4
,
.
.
.
x
i
n
)
\textbf{x}_i=(x_{i1},x_{i2},x_{i3},x_{i4},...x_{in})
xi=(xi1,xi2,xi3,xi4,...xin)是行向量
属性j | … | |
---|---|---|
元组1 | x 1 j x_{1j} x1j | |
元组2 | x 2 j x_{2j} x2j | |
元组3 | x 3 j x_{3j} x3j | |
元组i | x i j x_{ij} xij | … |
x
j
\textbf{x}_j
xj代表列,单纯地看表格实际上无法判断是否是列还是行向量,同样大多数默认是列向量,具体需要看定义。
形如
x
j
=
(
x
1
j
,
x
2
j
,
x
3
j
,
x
4
j
,
.
.
.
x
n
j
)
T
\textbf{x}_j=(x_{1j},x_{2j},x_{3j},x_{4j},...x_{nj})^T
xj=(x1j,x2j,x3j,x4j,...xnj)T是列向量
形如
x
j
=
(
x
1
j
,
x
2
j
,
x
3
j
,
x
4
j
,
.
.
.
x
n
j
)
T
\textbf{x}_j=(x_{1j},x_{2j},x_{3j},x_{4j},...x_{nj})^T
xj=(x1j,x2j,x3j,x4j,...xnj)T是行向量
其对应的具体实例:
编号 | 年龄 | 性别 | 身高 | 体重 | 电话 |
---|---|---|---|---|---|
1 | 18 | 男 | 180 | 80 | 18938298162 |
2 | 17 | 男 | 180 | 80 | 18938298152 |
3 | 15 | 男 | 180 | 80 | 18938298142 |
4 | 16 | 男 | 180 | 80 | 18938298132 |
5 | 14 | 男 | 180 | 80 | 18938294122 |
定义
x
i
=
(
x
i
1
,
x
i
2
,
x
i
3
,
x
i
4
,
.
.
.
x
i
n
)
T
\textbf{x}_i=(x_{i1},x_{i2},x_{i3},x_{i4},...x_{in})^T
xi=(xi1,xi2,xi3,xi4,...xin)T
那么第一行可以表示为
x
1
=
(
1
,
18
,
男
,
180
,
80
,
18938298162
)
T
\textbf{x}_1=(1,18,男,180,80,18938298162)^T
x1=(1,18,男,180,80,18938298162)T
定义
x
j
=
(
x
1
j
,
x
2
j
,
x
3
j
,
x
4
j
,
.
.
.
x
n
j
)
T
\textbf{x}_j=(x_{1j},x_{2j},x_{3j},x_{4j},...x_{nj})^T
xj=(x1j,x2j,x3j,x4j,...xnj)T
那么第一列可以表示为
x
j
=
(
1
,
2
,
3
,
4
,
5
)
T
\textbf{x}_j=(1,2,3,4,5)^T
xj=(1,2,3,4,5)T
示例2.1
但是有些表格不按照上面的通用格式,例如
元组1 | 元组2 | 元组3 | 元组4 | … | |
---|---|---|---|---|---|
属性1 | |||||
属性2 | |||||
属性3 | |||||
属性4 |
当看到这种形式的表时,就需要警惕
x
i
\textbf{x}_i
xi与
x
j
\textbf{x}_j
xj究竟是代表行还是列。
一般默认
x
i
\textbf{x}_i
xi表示行,默认是列向量,具体看定义甚至文章语义。
元组1 | 元组2 | 元组3 | 元组4 | … | |
---|---|---|---|---|---|
属性i | x i 1 {x}_{i1} xi1 | x i 2 {x}_{i2} xi2 | x i 3 {x}_{i3} xi3 | x i 4 {x}_{i4} xi4 | … |
一般情况下默认 x j \textbf{x}_j xj仍然表示列,默认是列向量,具体看定义甚至文章语义。
元组j | … | |
---|---|---|
属性1 | x 1 j x_{1j} x1j | |
属性2 | x 2 j x_{2j} x2j | |
属性3 | x 3 j x_{3j} x3j | |
属性4 | x 4 j x_{4j} x4j |
示例2.2
有时也会出现例外,仍然是示例2.1中的表格,如下:(个人认为可能作者本人找到的示例表格形式是示例2.1的形式,但是理论知识却按照示例1的格式,作者只是想稍微偷懒不改了。)
此时
x
j
\textbf{x}_j
xj表示行,默认是列向量,具体是什么向量看定义甚至文章语义。
元组1 | 元组2 | 元组3 | 元组4 | … | |
---|---|---|---|---|---|
属性j | x 1 j {x}_{1j} x1j | x 2 j {x}_{2j} x2j | x 3 j {x}_{3j} x3j | x 4 j {x}_{4j} x4j | … |
x i \textbf{x}_i xi表示列,默认是列向量,具体是什么向量看定义甚至文章语义。
元组i | … | |
---|---|---|
属性1 | x i 1 x_{i1} xi1 | |
属性2 | x i 2 x_{i2} xi2 | |
属性3 | x i 3 x_{i3} xi3 | |
属性4 | x i 4 x_{i4} xi4 |
实际上遇到这种情况时可以将表格转置成示例一的形式。
正题开始–Linear regression
线性回归实际上是一种拟合方式,在现实应用中如果明确知道数据符合线性关系,那么直接使用即可;但是当不知道数据的关系时,如果是二维或三维数据,将离散数据绘制出来,观察数据之间的关系大致上符合线性关系也可以应用,如果是高维数据,线性回归可以作为一个数学模型参考使用,最终需要分析误差来决定是否采用这种模型。
符号定义
为了方便可视化,先处理二维数据看看实验效果,如下:
属性x | 结果y |
---|---|
1 | 7 |
2 | 10 |
3 | 12 |
4 | 16 |
5 | 18 |
6 | 23 |
x
i
=
(
x
i
1
,
x
i
2
,
x
i
3
,
x
i
4
,
.
.
.
,
x
i
j
,
.
.
.
x
i
n
)
T
\textbf{x}_i=(x_{i1},x_{i2},x_{i3},x_{i4},...,x_{ij},...x_{in})^T
xi=(xi1,xi2,xi3,xi4,...,xij,...xin)T是列向量,在表格中表示属性行,不包括结果。
x
j
=
(
x
1
j
,
x
2
j
,
x
3
j
,
x
4
j
,
.
.
.
,
x
i
j
,
.
.
.
x
m
j
)
T
\textbf{x}_j=(x_{1j},x_{2j},x_{3j},x_{4j},...,x_{ij},...x_{mj})^T
xj=(x1j,x2j,x3j,x4j,...,xij,...xmj)T是列向量,在表格中表示列。
x
=
(
x
1
T
;
x
2
T
;
x
3
T
;
.
.
.
;
x
m
T
)
\textbf{x}=(\textbf{x}_1^T;\textbf{x}_2^T;\textbf{x}_3^T;...;\textbf{x}_m^T)
x=(x1T;x2T;x3T;...;xmT)
x
i
j
x_{ij}
xij表示第i行第j列属性,对应表中的属性值。
y
^
=
(
y
^
1
,
y
^
2
,
.
.
.
,
y
^
i
,
.
.
.
y
^
m
)
T
\hat\textbf{y}=(\hat{y}_1,\hat{y}_2,...,\hat{y}_i,...\hat{y}_m)^T
y^=(y^1,y^2,...,y^i,...y^m)T是列向量,对应表格中的结果。
y
=
(
y
1
,
y
2
,
.
.
.
,
y
i
,
.
.
.
y
m
)
T
\textbf{y}=(y_1,y_2,...,y_i,...y_m)^T
y=(y1,y2,...,yi,...ym)T是列向量,对应拟合的结果。
w
=
(
w
1
,
w
2
,
.
.
.
w
j
,
.
.
.
,
w
n
)
T
\textbf{w}=(w_1,w_2,...w_j,...,w_n)^T
w=(w1,w2,...wj,...,wn)T是列向量,表示回归系数。
b是常数项。
推导过程
线性回归公式
y
i
=
∑
j
=
1
n
w
j
x
i
j
+
b
=
x
i
T
w
+
b
{y}_i=\sum_{j=1}^{n}w_jx_{ij}+b=\textbf{x}_i^T\textbf{w}+b
yi=j=1∑nwjxij+b=xiTw+b
为了方便矩阵表示,可以重定义向量
w
\textbf{w}
w以及向量
x
i
\textbf{x}_i
xi
w
=
(
b
,
w
1
,
w
2
,
.
.
.
,
w
j
,
.
.
.
,
w
n
)
T
\textbf{w}=(b,w_1,w_2,...,w_j,...,w_n)^T
w=(b,w1,w2,...,wj,...,wn)T
x
i
=
(
1
,
x
i
1
,
x
i
2
,
x
i
3
,
x
i
4
,
.
.
.
,
x
i
j
,
.
.
.
,
x
i
n
)
T
\textbf{x}_i=(1,x_{i1},x_{i2},x_{i3},x_{i4},...,x_{ij},...,x_{in})^T
xi=(1,xi1,xi2,xi3,xi4,...,xij,...,xin)T
重定义之后
y
i
=
x
i
T
w
y_i=\textbf{x}_i^T\textbf{w}
yi=xiTw
y
=
xw
(1)
\textbf{y}=\textbf{x}\textbf{w}\tag{1}
y=xw(1)
为了求出系数
w
\textbf{w}
w,求解策略是使拟合结果与实际结果误差最小。
误差公式表示为
e
=
∑
i
=
1
m
(
y
i
−
y
^
i
)
2
=
(
y
−
y
^
)
T
(
y
−
y
^
)
=
(
xw
−
y
^
)
T
(
xw
−
y
^
)
(2)
e=\sum_{i=1}^{m}(y_i-\hat{y}_i)^2=(\textbf{y}-\hat\textbf{y})^T(\textbf{y}-\hat\textbf{y})=(\textbf{x}\textbf{w}-\hat\textbf{y})^T(\textbf{x}\textbf{w}-\hat\textbf{y})\tag{2}
e=i=1∑m(yi−y^i)2=(y−y^)T(y−y^)=(xw−y^)T(xw−y^)(2)
求解目标
arg
min
w
e
(3)
\arg \space \min\limits_{\textbf{w}}{e}\tag{3}
arg wmine(3)
可以证明e是关于
w
\textbf{w}
w的一个凸函数,e存在最小值且为极小值,那么e的极小值点(导数为0)所在位置就是向量
w
\textbf{w}
w的取值,接下来需要对向量求导(具体方法见我的另一篇博文),
d
e
d
w
=
d
(
(
xw
−
y
^
)
T
(
xw
−
y
^
)
)
w
\frac{de}{d\textbf{w}}=\frac{d((\textbf{x}\textbf{w}-\hat\textbf{y})^T(\textbf{x}\textbf{w}-\hat\textbf{y}))}{\textbf{w}}
dwde=wd((xw−y^)T(xw−y^))
复合函数求导,令向量
u
=
(
xw
−
y
^
)
\textbf{u}=(\textbf{x}\textbf{w}-\hat\textbf{y})
u=(xw−y^)
d
e
d
w
=
d
u
d
w
d
u
T
u
d
u
=
x
T
2
(
xw
−
y
^
)
=
2
(
x
T
xw
−
x
T
y
^
)
(4)
\frac{de}{d\textbf{w}}=\frac{d\textbf{u}}{d\textbf{w}}\frac{d\textbf{u}^T\textbf{u}}{d\textbf{u}}=\textbf{x}^T2(\textbf{x}\textbf{w}-\hat\textbf{y})=2(\textbf{x}^T\textbf{x}\textbf{w}-\textbf{x}^T\hat\textbf{y})\tag{4}
dwde=dwdududuTu=xT2(xw−y^)=2(xTxw−xTy^)(4)
最终的求解结果是个列向量且
d
e
d
w
=
0
=
(
0
,
0
,
0
,
.
.
.
,
0
)
T
\frac{de}{d\textbf{w}}=\textbf{0}=(0,0,0,...,0)^T
dwde=0=(0,0,0,...,0)T
很容易看出
w
=
(
x
T
x
)
−
1
x
T
y
^
\textbf{w}=(\textbf{x}^T\textbf{x})^{-1}\textbf{x}^T\hat\textbf{y}
w=(xTx)−1xTy^,当且仅当
x
T
x
\textbf{x}^T\textbf{x}
xTx的逆存在时成立。
此时有些人可能会疑惑
w
=
x
−
1
y
^
\textbf{w}=\textbf{x}^{-1}\hat\textbf{y}
w=x−1y^时也成立,但是实际情况是
x
\textbf{x}
x常常不是方阵,没有逆。
常用解法
1.牛顿法
牛顿法的迭代公式为
w
t
+
1
=
w
t
−
▽
2
f
(
w
t
)
−
1
▽
f
(
w
t
)
\textbf{w}_{t+1}=\textbf{w}_t-\triangledown{^2f(\textbf{w}_t)}^{-1}\triangledown{f(\textbf{w}_t)}
wt+1=wt−▽2f(wt)−1▽f(wt)
由(4)得
▽
f
(
w
)
=
2
(
x
T
xw
−
x
T
y
^
)
\triangledown{f(\textbf{w})}=2(\textbf{x}^T\textbf{x}\textbf{w}-\textbf{x}^T\hat\textbf{y})
▽f(w)=2(xTxw−xTy^)
求
▽
2
f
(
w
)
=
2
x
T
xE
\triangledown{^2f(\textbf{w})}=2\textbf{x}^T\textbf{x}\textbf{E}
▽2f(w)=2xTxE
那么
▽
2
f
(
w
)
−
1
=
1
2
(
x
T
x
)
−
1
\triangledown{^2f(\textbf{w})}^{-1}=\frac{1}{2}(\textbf{x}^T\textbf{x})^{-1}
▽2f(w)−1=21(xTx)−1
那么
w
t
+
1
=
w
t
−
(
x
T
x
)
−
1
(
x
T
x
w
t
−
x
T
y
^
)
=
(
x
T
x
)
−
1
x
T
y
^
\textbf{w}_{t+1}=\textbf{w}_{t}-(\textbf{x}^T\textbf{x})^{-1}(\textbf{x}^T\textbf{x}\textbf{w}_t-\textbf{x}^T\hat\textbf{y})=(\textbf{x}^T\textbf{x})^{-1}\textbf{x}^T\hat\textbf{y}
wt+1=wt−(xTx)−1(xTxwt−xTy^)=(xTx)−1xTy^
对于本问题一步迭代即可得到最终结果,而且与上面看出的结果一致。
直接在控制台上运行
2.梯度下降法
梯度下降法迭代公式为
w
t
+
1
=
w
t
−
α
▽
f
(
w
t
)
\textbf{w}_{t+1}=\textbf{w}_t-\alpha\triangledown{f(\textbf{w}_t)}
wt+1=wt−α▽f(wt)
实验结果
可以看到梯度下降算法与牛顿法求取的结果近似相等。
绘制图像
python代码
import numpy as np
import matplotlib.pyplot as plt
#输入数据以及\alpha
#梯度下降
def desc(data,a):
row,col=data.shape
w1=np.zeros([col,1])
w=np.ones([col,1])
#合并数组
x=np.append(data[:,:-1],np.ones([row,1]),axis=1)
y=data[:,-1:]
while sum(abs(w1-w))[0]>0.001:
w=w1
dfw=2*(np.dot(np.dot(x.T,x),w)-np.dot(x.T,y))
w1=w-a*dfw
return w1
#针对二维平面绘制图像
def draw(data):
x=np.array(range(0,10))
y=w[0]*x+w[1]
plt.scatter(data[:,:-1],data[:,-1:])
plt.plot(x,y,color='red')
plt.show()
data=np.array([[1,2,3,4,5,6],[7,10,12,16,18,23]]).T
w=desc(data,0.01)
print(desc(data,0.01))
draw(data)