Andrew Zhang
Nov 22, 2016
A major weakness of linear regression is overfitting. When more data cannot be obtained, there are still several remedies: lasso regression, which adds an L1 penalty; ridge regression, which adds an L2 penalty; and the early stopping algorithm discussed today.
The early stopping algorithm is based on resampling. Each resampled data set is split into a training set and a validation set; a regression model is fit on the training set while its accuracy is monitored on the validation set, and training ends once accuracy on both the training set and the validation set stops improving.
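As a sketch of this idea, here is a minimal Python version (names and defaults are illustrative, not taken from the MATLAB code below): plain gradient descent on the squared error, stopping as soon as the validation error rises.

```python
import numpy as np

def early_stopping_fit(X, y, lr=1e-3, max_iter=1000, val_frac=0.2, seed=0):
    """Fit linear regression by gradient descent; stop when the
    validation error stops decreasing (a sketch of early stopping)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    n_val = int(len(y) * val_frac)
    val, train = idx[:n_val], idx[n_val:]
    Xt, yt = X[train], y[train]
    Xv, yv = X[val], y[val]
    w = np.zeros(X.shape[1])
    best_val = np.inf
    for it in range(max_iter):
        grad = Xt.T @ (Xt @ w - yt)   # gradient of the squared training error
        w -= lr * grad
        val_err = np.sum((yv - Xv @ w) ** 2)
        if val_err > best_val:        # validation error rose: stop early
            break
        best_val = val_err
    return w, it + 1
```

On well-conditioned data the loop converges long before `max_iter`; on noisy data the validation error eventually fluctuates and triggers the break.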
Next, let's look at my implementation of the early stopping algorithm.
First, the code that splits the data into an 80% training set and a 20% stopping set (validation set):
function [h,b,iterationNum]=EarlyStoppingParameterEstimate(y,x,maxIterationNum)
% Split the samples into an 80% training set and a 20% stopping (validation) set.
sampleTotalNum=length(y);
index=randperm(sampleTotalNum);
stoppingSetNum=floor(sampleTotalNum*0.2);
stoppingSetIndex=index(1:stoppingSetNum);
trainingSetIndex=index(stoppingSetNum+1:end);
y_stoppingSet=y(stoppingSetIndex,:);
x_stoppingSet=x(stoppingSetIndex,:);
y_trainingSet=y(trainingSetIndex,:);
x_trainingSet=x(trainingSetIndex,:);
% Standardize the training features and center the training targets.
y_trainingSet_mu=mean(y_trainingSet);
x_trainingSet_mu=mean(x_trainingSet);
x_trainingSet_sigma=std(x_trainingSet);
x_training_normalized=(x_trainingSet-repmat(x_trainingSet_mu,size(x_trainingSet,1),1))./repmat(x_trainingSet_sigma,size(x_trainingSet,1),1);
y_training_normalized=y_trainingSet-y_trainingSet_mu;
parameterNum=size(x,2);
h=zeros(parameterNum,1);   % weights in the normalized space
g=zeros(parameterNum,1);   % accumulated gradient (momentum)
alpha=0.9;                 % momentum coefficient
lr=0.0001;                 % learning rate
error_trainingSet=inf;
error_stoppingSet=inf;
error_trainingset=[];
error_stoppingset=[];
for iterationNum=1:maxIterationNum
    % Gradient of the squared error, with a momentum term.
    g=x_training_normalized'*(x_training_normalized*h-y_training_normalized)+alpha*g;
    h=h-lr*g;
    % Stop if the training error no longer decreases.
    current_error_trainingSet=(y_training_normalized-x_training_normalized*h)'*(y_training_normalized-x_training_normalized*h);
    error_trainingset=[error_trainingset current_error_trainingSet];
    if(current_error_trainingSet>error_trainingSet)
        break;
    else
        error_trainingSet=current_error_trainingSet;
    end
    % Map the weights back to the original feature scale, evaluate on the
    % stopping set, and stop if its error no longer decreases.
    temp_h=h./(x_trainingSet_sigma');
    temp_b=y_trainingSet_mu-x_trainingSet_mu*temp_h;
    y_stoppingSet_predict=x_stoppingSet*temp_h+temp_b;
    current_error_stoppingSet=(y_stoppingSet-y_stoppingSet_predict)'*(y_stoppingSet-y_stoppingSet_predict);
    error_stoppingset=[error_stoppingset current_error_stoppingSet];
    if(current_error_stoppingSet>error_stoppingSet)
        break;
    else
        error_stoppingSet=current_error_stoppingSet;
    end
end
% figure
% hold on;
% error_trainingset=error_trainingset/size(x_trainingSet,1);
% error_stoppingset=error_stoppingset/stoppingSetNum;
% plot(error_trainingset,'g');
% plot(error_stoppingset,'r');
% legend('trainingSet','stoppingSet')
% Recover the weights and intercept on the original feature scale.
h=h./(x_trainingSet_sigma');
b=y_trainingSet_mu-x_trainingSet_mu*h;
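The bookkeeping at the end of the function (dividing the learned weights by the feature standard deviations and recovering the intercept from the means) can be illustrated with a small Python sketch. Here `np.linalg.lstsq` stands in for the gradient-descent fit, and all names are illustrative:

```python
import numpy as np

# Exact linear data with a known intercept, so the recovery is verifiable.
rng = np.random.default_rng(1)
X = rng.normal(5.0, 2.0, size=(200, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w + 3.0

mu_x, sigma_x, mu_y = X.mean(axis=0), X.std(axis=0), y.mean()
Xn = (X - mu_x) / sigma_x                    # standardized features
yn = y - mu_y                                # centered targets
h, *_ = np.linalg.lstsq(Xn, yn, rcond=None)  # weights in the normalized space

w = h / sigma_x                              # rescale, like h./sigma' above
b = mu_y - mu_x @ w                          # intercept, like mu_y - mu_x*h
```

Since y = ((X - mu_x)/sigma_x)·h + mu_y, distributing the terms gives exactly w = h/sigma_x and b = mu_y - mu_x·w.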
Next is the overall regression algorithm, combining bootstrap resampling with early stopping:
function [sigma,w,b]=EarlyStoppingRegression(y,x)
% Drop samples whose target is NaN.
picNum=length(y);
for i=picNum:-1:1
    if isnan(y(i))
        y(i)=[];
        x(i,:)=[];
    end
end
% Run early stopping once on the full data to fix the iteration budget.
[~,~,iterationNum]=EarlyStoppingParameterEstimate(y,x,200);
ValidatePicNum=length(y);
featureNum=size(x,2);
final_w=zeros(featureNum,1);
final_b=0;
cx=zeros(ValidatePicNum,featureNum);
cy=zeros(ValidatePicNum,1);
bootstrapNum=100;
for bootstrap=1:bootstrapNum
    % Draw a bootstrap sample of the same size, with replacement.
    for i=1:ValidatePicNum
        id=randi(ValidatePicNum);
        cx(i,:)=x(id,:);
        cy(i,:)=y(id,:);
    end
    % Fit on the bootstrap sample and accumulate the estimates.
    [cw,cb,~]=EarlyStoppingParameterEstimate(cy,cx,iterationNum);
    final_w=final_w+cw;
    final_b=final_b+cb;
end
% Average the bootstrap estimates; report the correlation between the
% predictions and the observed targets as a quality measure.
w=final_w/bootstrapNum;
b=final_b/bootstrapNum;
y_predict=x*w+b;
sigma=corr(y_predict,y);
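The bootstrap-averaging loop can likewise be sketched in Python. Again `np.linalg.lstsq` stands in for the early-stopping estimator, and all names are illustrative:

```python
import numpy as np

def bootstrap_average_fit(X, y, n_boot=100, seed=0):
    """Average linear fits over bootstrap resamples (bagging);
    a stand-in least-squares fit replaces the early-stopping estimator."""
    rng = np.random.default_rng(seed)
    n = len(y)
    w_sum = np.zeros(X.shape[1])
    b_sum = 0.0
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)            # sample with replacement
        Xb = np.column_stack([X[idx], np.ones(n)])  # append intercept column
        coef, *_ = np.linalg.lstsq(Xb, y[idx], rcond=None)
        w_sum += coef[:-1]
        b_sum += coef[-1]
    return w_sum / n_boot, b_sum / n_boot
```

Averaging over resamples reduces the variance of the individual fits, which is the point of combining bootstrap with early stopping here.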
Ridge regression
http://blog.csdn.net/zhangzhengyi03539/article/details/50042821
Lasso regression
http://blog.csdn.net/zhangzhengyi03539/article/details/50042951