理论部分请参阅李航博士所著《统计学习方法》一书第2章“感知机”
Point类表示待分类的样本点：x为特征向量，y为类别标记
package com.czb.ganzhiji;
/**
 * A labelled sample: a feature vector {@code x} together with its class label {@code y}.
 */
public class Point {
    /** Feature vector; a zero vector of length 2 when created with the no-arg constructor. */
    double[] x;
    /** Class label. */
    double y;

    /** Creates a sample whose feature vector is a fresh zero vector of length 2 and label 0. */
    public Point() {
        this(new double[2], 0);
    }

    /**
     * Creates a sample that stores the given feature vector by reference (no copy is made).
     *
     * @param x feature vector
     * @param y class label
     */
    public Point(double[] x, double y) {
        this.y = y;
        this.x = x;
    }
}
/**
* 感知机学习算法对偶形式的实现（参见《统计学习方法》算法2.2）
*/
package com.czb.ganzhiji;
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Perceptron learning algorithm in its dual form (Li Hang, "Statistical Learning Methods",
 * Algorithm 2.2).
 *
 * <p>The model is {@code f(x) = sign(sum_j a[j] * y_j * (x_j . x) + b)}. Training repeatedly
 * scans the samples from the first one and, for the first misclassified sample {@code i}
 * (functional margin {@code <= 0}), applies {@code a[i] += eta} and {@code b += eta * y_i}.
 * Once a full scan finds no misclassified sample, the primal weight vector is recovered as
 * {@code w = sum_i a[i] * y_i * x_i}.
 *
 * <p>NOTE: training only terminates when the samples are linearly separable; for
 * non-separable data {@code classify()} never returns.
 */
public class Ganzhiji2 {
    /** Primal weight vector, recovered from the dual coefficients after training. */
    private double[] w;
    /** Bias term; equals {@code sum_i a[i] * y_i}. */
    private double b = 0;
    /** Dual coefficients, one per training sample. */
    private double[] a;
    /** Learning rate; always positive. */
    private double eta;
    /** Training samples; labels are presumably +1 / -1 as in {@code main}. */
    ArrayList<Point> arrayList;

    /**
     * Creates a dual-form perceptron over the given samples.
     *
     * @param arrayList training samples; must be non-empty, and every feature vector is
     *     expected to have the same length as the first sample's
     * @param eta learning rate; must be positive
     * @throws NullPointerException if {@code arrayList} is null
     * @throws IllegalArgumentException if {@code arrayList} is empty or {@code eta} is not
     *     positive
     */
    public Ganzhiji2(ArrayList<Point> arrayList, double eta) {
        if (arrayList.isEmpty()) {
            throw new IllegalArgumentException("arrayList must contain at least one sample");
        }
        // Written as !(eta > 0) so NaN is rejected too; a non-positive step never converges.
        if (!(eta > 0)) {
            throw new IllegalArgumentException("eta must be positive, got " + eta);
        }
        this.arrayList = arrayList;
        w = new double[arrayList.get(0).x.length];
        a = new double[arrayList.size()];
        this.eta = eta;
    }

    /**
     * Creates a dual-form perceptron with the default learning rate {@code eta = 1}.
     *
     * @param arrayList training samples; must be non-empty
     */
    public Ganzhiji2(ArrayList<Point> arrayList) {
        this(arrayList, 1);
    }

    /** Returns the inner product of {@code x1} and {@code x2} over {@code x1.length} entries. */
    private double f(double[] x1, double[] x2) {
        double sum = 0;
        for (int i = 0; i < x1.length; i++) {
            sum += x1[i] * x2[i];
        }
        return sum;
    }

    /**
     * Returns the functional margin of sample {@code m} under the current dual model,
     * {@code y_m * (sum_j a[j] * y_j * (x_j . x_m) + b)}. A value {@code <= 0} means sample
     * {@code m} is misclassified or lies on the decision boundary.
     */
    private double g(ArrayList<Point> arrayList, int m) {
        Point target = arrayList.get(m);
        double sum = 0;
        for (int i = 0; i < arrayList.size(); i++) {
            Point sample = arrayList.get(i);
            sum += a[i] * sample.y * f(sample.x, target.x);
        }
        return target.y * (sum + b);
    }

    /**
     * Applies the dual-form update for misclassified sample {@code m}
     * ({@code a[m] += eta}, {@code b += eta * y_m}) and prints every dual coefficient followed
     * by {@code b} on one space-separated line.
     */
    private void h(ArrayList<Point> arrayList, int m) {
        a[m] += eta;
        // The bias step is scaled by eta so that b == sum_i a[i] * y_i stays consistent with
        // w == sum_i a[i] * y_i * x_i.
        b += eta * arrayList.get(m).y;

        // Print all coefficients rather than a fixed number, so any sample count works.
        StringBuilder line = new StringBuilder();
        for (double coefficient : a) {
            line.append(coefficient).append(' ');
        }
        line.append(b);
        System.out.println(line);
    }

    /** Returns the index of the first misclassified sample, or -1 if none is misclassified. */
    private int firstMisclassified() {
        for (int i = 0; i < arrayList.size(); i++) {
            if (g(arrayList, i) <= 0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Trains the model, then recovers and prints the primal weight vector {@code w} and the
     * bias {@code b}. Each scan restarts from the first sample after an update. This only
     * terminates for linearly separable data.
     */
    private void classify() {
        int misclassified;
        while ((misclassified = firstMisclassified()) >= 0) {
            h(arrayList, misclassified);
        }

        // Recompute w from scratch so that repeated calls do not accumulate into it.
        Arrays.fill(w, 0.0);
        for (int i = 0; i < arrayList.size(); i++) {
            Point sample = arrayList.get(i);
            double coefficient = a[i] * sample.y;
            for (int j = 0; j < w.length; j++) {
                w[j] += sample.x[j] * coefficient;
            }
        }
        System.out.println(Arrays.toString(w));
        System.out.println(b);
    }

    /**
     * Trains on the three samples of the book's dual-form example and prints the result
     * (expected: w = [1.0, 1.0], b = -3.0).
     */
    public static void main(String[] args) {
        Point point1 = new Point(new double[]{3, 3}, 1);
        Point point2 = new Point(new double[]{4, 3}, 1);
        Point point3 = new Point(new double[]{1, 1}, -1);
        ArrayList<Point> arrayList = new ArrayList<>();
        arrayList.add(point1);
        arrayList.add(point2);
        arrayList.add(point3);
        Ganzhiji2 ganzhiji2 = new Ganzhiji2(arrayList);
        ganzhiji2.classify();
    }
}