%% 画数据散点图
%第214对数据有问题,先删除
Data = xlsread('D:\Matlab\test\数据集\train.csv');
x = Data(:,1);
y = Data(:,2);
%均值归一化,在这里的数据中不进行均值归一化会造成运算为 NaN 的结果
x = (x- mean(x))./ (max(x)-min(x));
y = (y- mean(y))./ (max(y)-min(y));
plot(x,y,'.');
hold on;
%% 参数初始化
m = length(x);%样本数量
theta = [1;0];%theta初始化
%预先分配空间以节省运行时间
X = [ones(m,1),x];%特征值的增广矩阵
%梯度下降法
pd = zeros(m,2);%J对theta的偏导矩阵
cost = zeros(m,1);
alpha = 0.1;
itration = 10000;
%% 梯度下降法迭代寻找最小值
for i = 1:itration
h = X*theta;
cost = (y-h).*(y-h);
pd(:,1) = (h-y).*X(:,1);
pd(:,2) = (h-y).*X(:,2);
theta(1) = theta(1) - alpha/m*sum(pd(:,1));
theta(2) = theta(2) - alpha/m*sum(pd(:,2));
J = 1/(2*m)*sum(cost);
end
%% 正规方程法
% theta = ((X'*X)^(-1))*X'*y;
% n较小时使用正规方程法简单且速度较快但要注意X'X的可逆性问题;
% n较大时使用梯度下降法运算更快。
%% 画线
X = min(x):0.01:max(x);
Y = theta(1)+theta(2)*X;
plot(X,Y,'LineWidth',2);
效果图:
相关知识点和公式可参考我的另一篇博客中的线性回归部分:
机器学习基础----基于吴恩达机器学习课程的笔记_m0_61112058的博客-CSDN博客
数据集: