前言
几次接触机器学习的第一部分(像我背单词只记得abandoned一样),都会被线性模型中直接求解的这个正规方程(Normal Equation)搞得一头雾水,梯度下降还好理解些,但这个正规方程是真的一点头绪没有,西瓜书的周老师和统计学习方法的李老师都是传统的“抽象大师”,愚笨的我完全看不懂啊,在网上找到的博客也都是直接矩阵求导得到的,知其然不知其所以然。直到有一天突然遇到一个奇怪的老教师,仅仅用了不到一个小时就给我讲明白了,特来记录一下,也借此感谢这位奇怪但不失幽默的大师。
问题重述
我们有必要再回顾一下线性模型是解决什么问题:
问题的大意就是:如果给定某些确定的点,能否找到一个确定的线(hypothesis),把点连起来,使得这条线能过经过尽可能多的点。(以机器学习目标的角度来看,就是能否找到一个假设可以有更好的泛化性,对未知的x能预测出较为准确的y)
方法
当然,我们熟悉的就是最小二乘法,指定loss函数,然后使用梯度下降的方法,一次次更新参数,这个方法在吴恩达老师的视频里讲述的非常形象,这里不在赘述,主要想说一说另一种比较简单粗暴的“正规方程”做法。
这里我想先把正规方程放在这里,让大家有个印象,然后我们一步一步把它推出来:
Θ = ( X T X ) − 1 X T y \Theta = (X^TX)^{-1} X^Ty Θ=(XTX)−1XTy
正规方程
线性方程组
我们不妨换一个角度思考这个问题,如果这些点本来就在一条直线上呢? 那这个问题就和解多元线性方程组没有任何差别了,每一个点都是一个方程,我们很容易求解出一个 X X X满足所有方程,而这个 X X X 在一维上,就是我们题设里面要求直线(hypothesis)的两个参数 w w w 和 b b b,同时呢,如果扩展到多维上,也不过是 X X X 变成多维向量,也一一对应着线性模型中的参数 w w w
问题所在
所以问题在哪呢?问题就是这些点不在一条直线上啊! 举一个简单的例子,考虑下面这三个点 ( 1 , 1 ) (1,1) (1,1) ( 2 , 2 ) (2,2) (2,2) ( 3 , 2 ) (3,2) (3,2):
假设我们要求解的直线为 y = b + w x y = b + wx y=b+wx它们所对应的线性方程组 A x = b Ax =b Ax=b(注意这里的b对应直线里面的y)是:
{ b + w = 1 b + 2 w = 2 b + 3 w = 2 \begin{cases} b+w=1\\ b+2w = 2\\ b+3w=2\\ \end{cases} ⎩⎪⎨⎪⎧b+w=1b+2w=2b+3w=2
这里矩阵形式为:
A ∗ x = b [ 1 1 1 2 1 3 ] ∗ [ w b ] = [ 1 2 2 ] \begin{matrix} & A & * & x & = & b \\ & \begin{bmatrix}1&1\\ 1&2\\ 1&3\\\end{bmatrix} & * & \begin{bmatrix}w\\b\\\end{bmatrix} & = & \begin{bmatrix}1\\2\\2\\\end{bmatrix} \\ \end{matrix} A⎣⎡111123⎦⎤∗∗x[wb]==