感知机算法(二)---代码实现

  废话不多说了,这篇博文就是代码。


(1) 感知机学习算法的原式形式

package perceptron;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class Perceptron {
	public static void main(String[] args) throws IOException{
		
		 testPerceptron();
	}
	
	/*感知机算法:
	 * 主要内容:   
	 * 1,损失函数。推导过程:其实就是误分类点到超平面距离之和,去掉分母
	 * 2,迭代方式:随机梯度下降法。选定合适的初始点(很难),负梯度方向,固定步长
	 * 3,感知机学习是以误分类驱动的。
	 * 4,感知机方程满足的有很多,改变点的输入顺序,可以导致不同的结果*/
	
	public static void  testPerceptron(){
		//定义数据以及数据结构
		ArrayList<Integer[]> arr = new ArrayList<Integer[]>();
		Integer[] data1 = {3,3,1};
		Integer[] data2 = {4,3,1};
		Integer[] data3 = {1,1,-1};//每个数组的最后一个元素为类别标签,前两个元素为数据点
		
		arr.add(data1);
		arr.add(data2);
		arr.add(data3);

		//初始化w,b
		double w1 = 0;
		double w2 = 0;
		double b = 0;
		double e = 1;//步长
		
		//进行数据集的选择,每个数据点可能被遍历好多次
		
		iter(arr,w1,w2,b,e);
	
		
	}
	public static void iter(ArrayList<Integer[]> arr,double w1,double w2,double b,double  e){ //这个函数是递归函数
		
		String flag = "true";
		for(int i=0;i<arr.size();i++){
			double value = arr.get(i)[2]*(arr.get(i)[0]*w1+arr.get(i)[1]*w2+b);
			if(value<=0){//判断是不是被误分类了
				w1 = w1+e* arr.get(i)[0]*arr.get(i)[2];
				w2 = w2+e* arr.get(i)[1]*arr.get(i)[2];
				b = b+e*arr.get(i)[2];
				flag = "false";
				System.out.println(w1+"******"+w2+"*******"+b);
				break; //如果参数发生了变化,那么伴随该参数的迭代马上终止。
			}
			else
				continue;
		}
		if(flag.equals("false")){
			iter(arr,w1,w2,b,e);
		}
		else{
			System.out.println(w1+"******"+w2+"*******"+b);
		}
	}
}
	


(2)对偶形式

package perceptron;

import java.util.ArrayList;
import java.util.Arrays;

public class DualPerceptron {
	public static void main(String[] args) {
		
		testPerceptron();
	}
	
	public static void testPerceptron(){
		ArrayList<Integer[]> arr = new ArrayList<Integer[]>();
		Integer[] data1 = {3,3,1};
		Integer[] data2 = {4,3,1};
		Integer[] data3 = {1,1,-1};//每个数组的最后一个元素为类别标签,前两个元素为数据点
		
		arr.add(data1);
		arr.add(data2);
		arr.add(data3);
		
		int num =arr.size();
		
		double[][] Gram = Gram( arr);//计算求得GRAM矩阵
		
		double e = 1;
		//初始化参数,全部为0
		double[] paraArr = new double[arr.size()];
		for(int i=0;i<num;i++){
			paraArr[i] = 0;
		}
		double b = 0;
		iter(arr,paraArr,b,Gram,e );
		
	}
	
	//
	public static void iter(ArrayList<Integer[]> arr,double[] paraArr,double b ,double[][] Gram,double e ){ //这个函数是递归函数
		
		String flag = "true";
		for(int i=0;i<arr.size();i++){//从数据集中选取某个点
			
			double value = 0;
			
			for(int j=0;j<arr.size();j++){
				value = value+paraArr[j]*e*arr.get(j)[2]*Gram[j][i];
			}
			
			value = arr.get(i)[2]*(value+b);
			if(value<=0){
				flag = "false";
				paraArr[i] = paraArr[i]+e;
				b = b+e*arr.get(i)[2];
				break;
			}
			else
				continue;

		}
		System.out.println(flag);
		if(flag.equals("true")){
			System.out.println(Arrays.toString(paraArr)+"  "+b);
		}
		else{
			iter( arr,paraArr, b, Gram, e );
		}	
	}
	
	//函数功能:计算Gram矩阵。格拉姆矩阵
	public static double[][] Gram(ArrayList<Integer[]> arr){
		double[][] gram = new double[arr.size()][arr.size()];
		for(int i=0;i<arr.size();i++){
			for(int j=0;j<arr.size();j++){
				gram[i][j]=arr.get(i)[0]*arr.get(j)[0]+arr.get(i)[1]*arr.get(j)[1];
			}
		}
		return gram;
	}

}


  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值