导言:
由线性SVM出发,替换核函数,推测可以实现非线性分类,由此引申,可以实现其他功能。
1. 首先查看数据分布情况
scatter(rawdata(:,1),rawdata(:,2))
数据可被线性分类
2. 算法步骤
这里原谅我直接用了图片,打出来有点麻烦
3.程序解析
%采用启发式更新寻找最佳alpha
%启发式是指个体利用自身或者全局的经验来制定各自的搜索策略。此程序中对alpha用i遍历,随机配对j,共同进行优化,随机到了全局,故可以说是启发式
alpha=zeros(100,1);%定义一个初始化为O的alpha矩阵
C=0.6;%设置惩罚参数
b=0;
bi=0;
bj=0;
tolerance=0.001;%设置松弛系数
Ei=0;
Ej=0;
z=1;
%惩罚参数和松弛系数的设置关乎寻求最优模型,松弛系数是模型对离群点的容忍程度,也就是可以容忍离群点离群多远,当松弛系数设置为0时,就是硬间隔分类器。惩罚参数是模型对数据的重视程度,越大表明越不想丢掉这些点。
在设置这两个值之前,应该对数据有一个整体的把握,观察数据分布情况。比如,有的数据正样本太多而负样本少,或者负样本集中,这时分类可能产生偏移,需要去优化模型,可以在分类时设置负样本惩罚参数比正样本惩罚参数大,表示重视负样本。因此具体问题具体分析。
在对惩罚参数的选择上,一种思路是凭借经验进行寻优,另一种是通过算法进行寻优,就像SVM本身是寻找最小||w||,可以考虑将程序嵌入寻优算法,比如遗传算法之类。
while z<30%设置迭代次数
updatepairs=0;
for i=1:100%对数据集从第一个开始更新alpha
fxi=0;
fxj=0;
for k=1:100%计算由所有alpha推出的误差公式
fxi=fxi+alpha(k)*rawdata(k,3)*rawdata(i,1:2)*rawdata(k,1:2)';
end
Ei=fxi+b-rawdata(i,3);
if(rawdata(i,3)*Ei< -tolerance && alpha(i) < C) || (rawdata(i,3)*Ei> tolerance && alpha(i) > 0)%选择不满足kkt条件的alpha进行更新
j=floor(rand(1)*100)+1;%随机选择一个配对的alpha进行更新
while(i==j)
j=floor(rand(1)*100)+1;
end
for k=1:100
fxj=fxj+alpha(k)*rawdata(k,3)*rawdata(j,1:2)*rawdata(k,1:2)';
end
Ej=fxj+b-rawdata(j,3);
alphaiold=alpha(i);%保存了旧的alpha值
alphajold=alpha(j);
if rawdata(i,3)~=rawdata(j,3)%计算alpha的上下界
L=max(0, alphajold-alphaiold); %用的是旧的alpha值
H=min(C,C+ alphajold-alphaiold);
else
L=max(0, alphajold+alphaiold-C);
H=min(C, alphajold+alphaiold);
end
rate=rawdata(i,1:2)*rawdata(i,1:2)'+rawdata(j,1:2)*rawdata(j,1:2)'-2*rawdata(i,1:2)*rawdata(j,1:2)';%计算更新速率
alpha(j)=alphajold+rawdata(j,3)*(Ei-Ej)/rate;%更新alphaj
if alpha(j)>H%修剪alphaj
alpha(j)=H;
elseif alpha(j)<L
alpha(j)=L;
end
alpha(i)=alphaiold+rawdata(i,3)*rawdata(j,3)*(alphajold-alpha(j)); %更新alphai bi=b-Ei-rawdata(i,3)*(alpha(i)-alphaiold)*rawdata(i,1:2)*rawdata(i,1:2)'-rawdata(j,3)*(alpha(j)-alphajold)*rawdata(j,1:2)*rawdata(i,1:2)';
bj=b-Ej-rawdata(i,3)*(alpha(i)-alphaiold)*rawdata(i,1:2)*rawdata(j,1:2)'-rawdata(j,3)*(alpha(j)-alphajold)*rawdata(j,1:2)*rawdata(j,1:2)';
if alpha(i)>0&&alpha(i)<C%确定b值
b=bi;
elseif alpha(j)>0&&alpha(j)<C
b=bj;
else
b=(bi+bj)/2;
end
updatepairs=updatepairs+1;
else
continue;
end
end
if( updatepairs==0) %设置循环终止条件,即当所有的alpha都不再更新30次就代表已经寻找到了最优解
z=z+1;
else
z=0;
end
end
%找到确定直线的点
w=zeros(100,2);
%去掉不小于C的点
for i=1:100
if alpha(i)>=C
alpha(i)=0;
end
end
temp=alpha.*rawdata(:,3);
for i=1:length(alpha)
w(i,:)=temp(i).*rawdata(i,1:2);
end
w=sum(w);
x1=max(rawdata(:,1));y2=min(rawdata(:,2));
y1=(-w(1)*x1-b)/w(2);
x2=(-w(2)*y2-b)/w(1);
scatter(rawdata(:,1),rawdata(:,2));
grid on;
text(rawdata(alpha>0,1),rawdata(alpha>0,2),'o','color','r');
hold;
plot([x1,x2],[y1,y2]);
学习博客:https://cuijiahua.com/blog/2017/11/ml_8_svm_1.html
学习笔记: