统计学习方法之感知机对偶形式Java实现代码

理论部分请参照李航博士的统计学习方法一书
Point类表示需要分类的样本点
package com.czb.ganzhiji;

public class Point {
	
	double x[]=new double[2];
	double y;
	
	public Point(){
		
	}
	
	public Point(double x[],double y){
		this.x=x;
		this.y=y;
	}

}

/**
 * 感知机对偶形式的代码
 */
package com.czb.ganzhiji;

import java.util.ArrayList;
import java.util.Arrays;

public class Ganzhiji2 {
	private double w[];
	private double b=0;
	private double a[];
	private double eta;
	ArrayList<Point> arrayList;
	
	public Ganzhiji2(ArrayList<Point> arrayList,double eta){
		this.arrayList=arrayList;
		w=new double[arrayList.get(0).x.length];
		a=new double[arrayList.size()];
		this.eta=eta;
	}
	
	public Ganzhiji2(ArrayList<Point> arrayList){
		this.arrayList=arrayList;
		w=new double[arrayList.get(0).x.length];
		a=new double[arrayList.size()];
		this.eta=1;
	}
	
	private double f(double x1[],double x2[]){//进行两个向量的内积计算
		double sum=0;
		for(int i=0;i<x1.length;i++){
			sum=sum+x1[i]*x2[i];
		}
		return sum;
	}
	
	private double g(ArrayList<Point> arrayList,int m){//用来判断模型
		double sum=0;
		for(int i=0;i<arrayList.size();i++){
			sum=sum+a[i]*arrayList.get(i).y*f(arrayList.get(i).x, arrayList.get(m).x);
		}
		return arrayList.get(m).y*(sum+b);
	}
	
	private void h(ArrayList<Point> arrayList,int m){//用来更新a和b
		a[m]=a[m]+eta;
		b=b+arrayList.get(m).y;
		
		System.out.print(a[0]+" "+a[1]+" "+a[2]+" "+b);
		System.out.println();
	}
	
	private void classify(){
		boolean flag=false;
		
		while(!flag){
			for(int i=0;i<arrayList.size();i++){
				if(g(arrayList, i)<=0){
					h(arrayList, i);
					break;
				}
				if(i==arrayList.size()-1){
					flag=true;
				}
			}
		}
		for(int i=0;i<arrayList.size();i++){
			double temp1=a[i]*arrayList.get(i).y;
			
			for(int j=0;j<arrayList.get(0).x.length;j++){
				if(j==0)
					w[j]+=arrayList.get(i).x[j]*temp1;
				else
					w[j]+=arrayList.get(i).x[j]*temp1;
			}
			
		}
		
		System.out.println(Arrays.toString(w));
		System.out.println(b);
	}

	public static void main(String[] args) {
		Point point1=new Point(new double[]{3, 3},1);
		Point point2=new Point(new double[]{4, 3},1);
		Point point3=new Point(new double[]{1, 1},-1);
		
		ArrayList<Point> arrayList=new ArrayList<>();
		arrayList.add(point1);
		arrayList.add(point2);
		arrayList.add(point3);
		
		Ganzhiji2 ganzhiji2=new Ganzhiji2(arrayList);
		ganzhiji2.classify();
	}

}

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值