线性回归是什么?
所谓线性回归(以单变量为例),说白了就是给你一堆点,你要从这一堆点中找出一条直线。如下图
本文截图均出自 Andrew Ng的<
机器学习公开课>
找到这条直线以后可以做什么? 我们假设我们找到了 代表这条直线的 a 和b,那么直线表达式为
y = a + b*x, 那么当 新来一个 x,我们就可以知道 y了。
Andrew Ng第一节课就说过,什么是机器学习?就是
A computer program is said to learn from experience E with
respect to some task T and some performance measure P if its
performance on T, as measured by P, improves with experience E.
这句话真心不好翻译:从经验E中学习如何完成任务T,并且用方法P来衡量T的好坏。通过经验E的学习,可以不断提高用P来衡量的任务T的表现。(轻拍)
OK,那么线性回归就是做这样一件事,给你一堆历史数据,你训练出一条直线(以单变量线性回归为例),那么再有 新的x输入时,通过这条直线的表达式,你就可以知道 y是多少啦。从而达到预测的目的。
怎么求这条直线?
看着上面那条绿线画的挺容易。其实,它背后代表的数学逻辑是:所有的点到这条直线的距离之和最小。 Andrew Ng称之为 cost function,见下图
本文截图均出自 Andrew Ng的<
机器学习公开课>
J 是一个关于 theta0 和 theta1的二次函数,二次函数求最小值当然是
最小二乘法,theta少的时候,还可以用手算,但始终不是长久之计。下面看 Gradient descent,梯度下降法,
给 theta 选定一个初始值(一般选为1),通过下面的公式更新 theta。其中alpha 大于0;
如果 theta的偏微分(斜率) > 0, 说明当前theta取值在最低点的右侧,那么通过下面的公式,theta就会左移(减小),只要alpha取值合理,theta就会靠近最低点。反之亦然
本文截图均出自 Andrew Ng的<
机器学习公开课>
theta都求出来了,那么这个直线也就可以表示出来了,算法描述如下:
重要的事情说三遍:本文截图均出自 Andrew Ng的<
机器学习公开课>
talk is easy, show me the code: 下面是matlab实现,非常简单,这里参考了Andrew Ng作业模版
function [theta, J_history] = gradientDescentMulti(X, y, theta, alpha, num_iters)
% Initialize some useful values
m = length(y); % number of training examples
J_history = zeros(num_iters, 1);
% 迭代次数是自己定的,具体多少还的看实验中的收敛情况
for iter = 1:num_iters
% Hint: While debugging, it can be useful to print out the values
% of the cost function (computeCostMulti) and gradient here.
%
h = X * theta; %毕竟X和theta都一致了,先计算h后面要用
error = h - y; %当前误差
theta = theta - alpha * (1/m * sum(error .* X)’); %带入迭代公式,alpha 右边那个就是偏微分的计算公式,sum都是指column sum,对所有的训练数据的每一个变量分别求和。向量error 和矩阵X的 .* 就是error[0] * X[, 0](第一列)
% ============================================================
% Save the cost J in every iteration
J_history(iter) = computeCostMulti(X, y, theta);
end
end
拟合出来的曲线大概是下面的样子
多说两句,为了确定 GD方法正确的运行,可以记录 每次迭代的 J值,画出如下曲线,下面的都是不正常的
总结:如果alpha太小,收敛的会比较慢; 如果alpha太大,J 值可能不会收敛
alpha如何取值? try … 0.001, 0.003, 0.01, 0.03, 0.1, 0.3 , 1… …
为了使GS算法运行的更舒服一点,还有 feature scaling 和 mean normalization等方法先处理数据,今天先到这儿