本文参考链接:https://www.cnblogs.com/youngsea/p/7368359.html。
进行了部分改进。
超平面方程:
w
→
⋅
x
→
+
b
=
0
\overrightarrow{w}\cdot \overrightarrow{x}+b=0
w⋅x+b=0
对二维的情况来说就是:
w
→
=
(
w
1
,
w
2
)
,
x
→
=
(
x
1
,
x
2
)
\overrightarrow{w}=(w_1,w_2),\overrightarrow{x}=(x_1,x_2)
w=(w1,w2),x=(x1,x2),则超平面方程为
w
1
x
1
+
w
2
x
2
+
b
=
0
w_1x_1+w_2x_2+b=0
w1x1+w2x2+b=0
⇒
x
2
=
−
1
w
2
(
w
1
x
1
+
b
)
\Rightarrow x_2=-\frac{1}{w_2}(w_1x_1+b)
⇒x2=−w21(w1x1+b)
是一条直线。
训练数据的文件,1.txt:
3.54 1.97 -1
3.01 2.55 -1
7.55 -1.58 1
2.11 0 -1
8.12 1.27 1
7.11 -0.98 1
8.61 2.05 1
2.32 0.26 -1
3.63 1.73 -1
0.34 -0.89 -1
3.12 0.29 -1
2.12 -0.78 -1
0.88 -2.79 -1
7.13 -2.32 1
1.69 -1.21 -1
8.11 0.63 1
8.49 -0.26 1
4.65 3.55 -1
8.19 1.55 1
1.20 0.21 -1
MATLAB脚本,显示迭代过程超平面( w → 和 b \overrightarrow{w}和b w和b)的动态变化:
clc;clear;
data=load('1.txt');
x=[data(:,1),data(:,2)];%x的第一列是x1,第二列是x2
y=data(:,3); %类别
l=length(y);
%% 根据数据的标签画出散点图
for j=1:l
if y(j)==1
plot(x(j,1),x(j,2),'o');
hold on
end
if y(j)==-1
plot(x(j,1),x(j,2),'x');
hold on
end
end
%% 初始化参数,对应算法第一步
w=[0,0]; %训练数据是二维的,故w也是二维的
b=0;
r=0.5; %学习率
con=0; %set the condition
t=0; %迭代次数
b_record=[]; %记录b的历史值
w_record=[]; %记录w的历史值
while con==0 %条件为训练集没有误分类点
for i=1:l
if (y(i)*(dot(w,x(i,:))+b))<=0 %若是误分类点,则小于0
w(1)=w(1)+r*y(i)*x(i,1); %误分类点更新w值
w(2)=w(2)+r*y(i)*x(i,2);
b=b+r*y(i);
w=[w(1),w(2)];
w_record=[w_record;w];
b_record=[b_record,b];
t=t+1;
end
end
for i=1:l
con1(i)=(y(i)*(dot(w,x(i,:))+b)); %con1存储更新后的w和b再再计算一次所有点的y(wx+b)的结果,
end
con=(all(con1(:)>0)); %如果con1全为正,则con为1 ;all(x):如果所有元素都非0,all(x)返回1,否则返回0.
end
%% 迭代完成后开始画图
for i=1:t
xt=-2:0.1:10; %画出分类超平面
yt=(-w_record(i,1)*xt-b_record(i))/w_record(i,2);
h=plot(xt,yt);
axis([-2 10 -10 6]);
hold on
pause(0.5); %延时0.5秒
if(i~=t)
delete(h);
end
end