首先,分享以下我学习李航老师的《统计学习方法》中感知机原始形式学习笔记,如有错误或者其他见解,恳请指正。
感知机的对偶形式请参考我的另一篇blog: 感知机对偶形式C++实现
感知机的原始形式如下:
下面直接上代码,此处我用的是C++代码用STL中的向量实现存储,当然也可以用数组或者其他方式,感知机的原始形式代码如下:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>
using namespace std;
void input_init();
void train();
bool judge(int i);
void output();
///****w、b、步长和训练集********
vector<int> data[1000];
vector<int> w;
int b=0;
int len=1; ///步长
///****w、b、步长和训练集********
int n; //特征向量的维数
int N; //输入实例的个数
int k=0; //迭代次数
int main()
{
freopen("in.txt","r",stdin);
input_init();
train();
system("pause");
return 0;
}
void input_init(){
cout<<"请输入特征向量的维数:"<<endl;
cin>>n;
cout<<"请输入实例的个数:"<<endl;
cin>>N;
int tmp;
for(int i=0;i<N;i++){
cout<<"请输入第"<<i+1<<"个实例:"<<endl;
for(int j=0;j<=n;j++){
cin>>tmp;
data[i].push_back(tmp);
if(i==0&&j<n)
w.push_back(0);
}
}
cout<<endl;
cout<<"迭代次数|"<<"误分类点"<<" w\t"<<"b\t"<<"w.x+b"<<endl;
cout<<"0"<<"\t \t"<<" 0\t"<<"0\t"<<"0\t"<<endl;
}
void train(){
for(int i=0;i<N;i++)
if(!judge(i)){ //判断是否误分类
for(int j=0;j<n;j++)
w[j] += len*data[i][n]*data[i][j]; //w<-w+n*yi*xi
b += len*data[i][n]; //b<-b+n*yi
cout<<++k<<"\t"<<" x"<<i+1<<" ";
output();
i = -1; ///切记是-1不是0
}
cout<<++k<<"\t"<<" 0 ";
output();
}
bool judge(int i){
int sum = 0;
for(int j=0;j<n;j++)
sum += w[j]*data[i][j]; //w内积xi
return (data[i][n]*(sum+b)>0)? true:false; //yi(w.xi+b)<=0
}
void output(){
cout<<"\t";
for(int i=0;i<n;i++){
if(i==0) cout<<"("<<w[i];
else if(i==n-1)cout<<","<<w[i]<<")\t"<<b<<"\t";
else cout<<","<<w[i];
}
for(int i=0;i<n;i++){
if(w[i]>0){
if(i!=0) cout<<"+";
}
else if(w[i]<0)
cout<<"-";
else continue;
if(fabs(w[i])!=1.0)
cout<<abs(w[i]);
cout<<"x"<<i+1;
}
if(b!=0){
if(b>=0)
cout<<"+";
cout<<b;
}
cout<<endl;
}
此处我输入的文件in.txt
内容如下:
运行结果如下:
以上核心代码其实只有judge()
和train()
,去掉冗余输入输出后如下:
bool judge(int i){
int sum = 0;
for(int j=0;j<n;j++)
sum += w[j]*data[i][j]; //w内积xi
return (data[i][n]*(sum+b)>0)? true:false; //yi(w.xi+b)<=0
}
void train(){
for(int i=0;i<N;i++)
if(!judge(i)){ //判断是否误分类
for(int j=0;j<n;j++)
w[j] += len*data[i][n]*data[i][j]; //w<-w+n*yi*xi
b += len*data[i][n]; //b<-b+n*yi
i = -1; ///切记是-1不是0
}
}
归纳总结:
①判断误分类函数的for循环函数复杂度不同:
**对偶形式:**判断是否误分类函数judge()
for循环是:
for(int j=0;j<N;j++) //原始形式中是for(int j=0;j<n;j++)
if(a[j]==0) continue;
else sum += a[j]*data[j][n]*Gram[j][i]; //sum = ∑aj*yj*xj*xi
**原始形式:**判断是否误分类函数judge()
for循环是:
for(int j=0;j<n;j++)
sum += w[j]*data[i][j]; //w内积xi
由此可以看出对偶形式是[0,N-1],原始形式是[0,n-1],当n(特征向量维数)>>N(输入实例个数)
时,对偶形式循环复杂度较低,且对偶形式在训练数据之前先对训练数据求Gram矩阵,使得在此处本需要求xj*xi
内积时,直接调用Gram矩阵中的值即可,使得运算量降低。
②更新系数的方式不同:
**对偶形式:**训练数据的函数train()
更新系数的代码:
a[i] += len; //a<-a+n 此时步长len=1
b += len*data[i][n]; //b<-b+n*yi
**原始形式:**训练数据的函数train()
更新系数的代码:
for(int j=0;j<n;j++)
w[j] += len*data[i][n]*data[i][j]; //w<-w+n*yi*xi
b += len*data[i][n]; //b<-b+n*yi
由此可以看出对偶形式对αi
的更新是αi <- αi + len
;而原始形式对w
的更新是w <- w + len*yi*xi
,其中αi
是一个值,代表第i
个实例被误分类的次数(或称为被误分类而更新的次数),而w
是一个n
维的向量,所以更新会比αi
多需要一个for循环来执行。故此,更新复杂度也就比对偶形式高。