编写的函数如下:
function [ w,b ] = original_style( training_set,study_rate )
%training_set是一个m*n维矩阵,其中第一行是y_i,剩下的行的x_i
%选取初始值w_0,b_0
w=0;
b=0;
count=0; %每一次正确分类点个数
iteration_count=0; %迭代次数
fprintf('迭代次数\t误分类点\t\t权值w\t\t偏置b\t\n');%输出结果标题
while count ~= size(training_set,2)
count=0;
%在训练集中选取数据(x_i,y_i)
for i=1:size(training_set,2)
count = count+1;
%如果y_i(w*x_i+b)<=0,则对w和b进行相应的更新
if training_set(1,i)*(w'*training_set(2:size(training_set,1),i)+b)<=0
w = w + study_rate*training_set(1,i)*training_set(2:size(training_set,1),i);
b = b + study_rate*training_set(1,i);
iteration_count=iteration_count+1;
count=count-1;%不是正确分类点,减一
fprintf('\t%u\t',iteration_count);%输出迭代次数
fprintf('\t\t%u\t',i);%输出误分类点
fprintf('\t(%2.1g,%2.1g)''\t',w);%输出w
fprintf('\t%4.1g\n',b);%输出b
end
end
end
end
测试代码如下:
clear all;
training_set=[1,-1,1;3,1,4;3,1,3];
study_rate=1;
[w,b]=original_style( training_set,study_rate );
执行结果如下图: