废话不多说了,这篇博文就是代码。
(1) 感知机学习算法的原式形式
package perceptron;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
public class Perceptron {
public static void main(String[] args) throws IOException{
testPerceptron();
}
/*感知机算法:
* 主要内容:
* 1,损失函数。推导过程:其实就是误分类点到超平面距离之和,去掉分母
* 2,迭代方式:随机梯度下降法。选定合适的初始点(很难),负梯度方向,固定步长
* 3,感知机学习是以误分类驱动的。
* 4,感知机方程满足的有很多,改变点的输入顺序,可以导致不同的结果*/
public static void testPerceptron(){
//定义数据以及数据结构
ArrayList<Integer[]> arr = new ArrayList<Integer[]>();
Integer[] data1 = {3,3,1};
Integer[] data2 = {4,3,1};
Integer[] data3 = {1,1,-1};//每个数组的最后一个元素为类别标签,前两个元素为数据点
arr.add(data1);
arr.add(data2);
arr.add(data3);
//初始化w,b
double w1 = 0;
double w2 = 0;
double b = 0;
double e = 1;//步长
//进行数据集的选择,每个数据点可能被遍历好多次
iter(arr,w1,w2,b,e);
}
public static void iter(ArrayList<Integer[]> arr,double w1,double w2,double b,double e){ //这个函数是递归函数
String flag = "true";
for(int i=0;i<arr.size();i++){
double value = arr.get(i)[2]*(arr.get(i)[0]*w1+arr.get(i)[1]*w2+b);
if(value<=0){//判断是不是被误分类了
w1 = w1+e* arr.get(i)[0]*arr.get(i)[2];
w2 = w2+e* arr.get(i)[1]*arr.get(i)[2];
b = b+e*arr.get(i)[2];
flag = "false";
System.out.println(w1+"******"+w2+"*******"+b);
break; //如果参数发生了变化,那么伴随该参数的迭代马上终止。
}
else
continue;
}
if(flag.equals("false")){
iter(arr,w1,w2,b,e);
}
else{
System.out.println(w1+"******"+w2+"*******"+b);
}
}
}
(2)对偶形式
package perceptron;
import java.util.ArrayList;
import java.util.Arrays;
public class DualPerceptron {
public static void main(String[] args) {
testPerceptron();
}
public static void testPerceptron(){
ArrayList<Integer[]> arr = new ArrayList<Integer[]>();
Integer[] data1 = {3,3,1};
Integer[] data2 = {4,3,1};
Integer[] data3 = {1,1,-1};//每个数组的最后一个元素为类别标签,前两个元素为数据点
arr.add(data1);
arr.add(data2);
arr.add(data3);
int num =arr.size();
double[][] Gram = Gram( arr);//计算求得GRAM矩阵
double e = 1;
//初始化参数,全部为0
double[] paraArr = new double[arr.size()];
for(int i=0;i<num;i++){
paraArr[i] = 0;
}
double b = 0;
iter(arr,paraArr,b,Gram,e );
}
//
public static void iter(ArrayList<Integer[]> arr,double[] paraArr,double b ,double[][] Gram,double e ){ //这个函数是递归函数
String flag = "true";
for(int i=0;i<arr.size();i++){//从数据集中选取某个点
double value = 0;
for(int j=0;j<arr.size();j++){
value = value+paraArr[j]*e*arr.get(j)[2]*Gram[j][i];
}
value = arr.get(i)[2]*(value+b);
if(value<=0){
flag = "false";
paraArr[i] = paraArr[i]+e;
b = b+e*arr.get(i)[2];
break;
}
else
continue;
}
System.out.println(flag);
if(flag.equals("true")){
System.out.println(Arrays.toString(paraArr)+" "+b);
}
else{
iter( arr,paraArr, b, Gram, e );
}
}
//函数功能:计算Gram矩阵。格拉姆矩阵
public static double[][] Gram(ArrayList<Integer[]> arr){
double[][] gram = new double[arr.size()][arr.size()];
for(int i=0;i<arr.size();i++){
for(int j=0;j<arr.size();j++){
gram[i][j]=arr.get(i)[0]*arr.get(j)[0]+arr.get(i)[1]*arr.get(j)[1];
}
}
return gram;
}
}