package XBWKNN;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * KNN algorithm.
 *
 * @author XBW
 * @date 2014-08-16
 */
public classXBWKNN{public final static int KofKNN=5;public final static double weight[]={1,0.9,0.7,0.4,0.1}; //减法函数y=1-0.2*x
/*** knn
*@paramdata
*@paramds
*@returnans*/
public static intknn(Data data,DataSet ds){int ans = 0;
List dis=calcDis(data,ds);
ans=calcKDis(data,dis);returnans;
}/*** 计算训练集中所有向量的距离,排序之后取前K个
*@paramdata
*@paramds
*@return
*/@SuppressWarnings("null")public static ListcalcDis(Data data,DataSet ds){
List anslist =new ArrayList();double dx1=data.x1;double dx2=data.x2;double dx3=data.x3;for(int i=0;i
ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3));
anslist.add(ds.ds.get(i));
}
Collections.sort(anslist,new Comparator(){public intcompare(Data o1, Data o2) {
Double s=o1.costfun-o2.costfun;if(s<0)return -1;else
return 1;
}
});returnanslist;
}/*** 按一定的权重计算出前K个
*@paramdata
*@paramds
*@return
*/
public static int calcKDis(Data data,Listanslist){
Double[] anstype={0.0,0.0,0.0,0.0};for(int i=0;i
anstype[1]+=weight[i];
}else if(anslist.get(i).type==2){
anstype[2]+=weight[i];
}if(anslist.get(i).type==3){
anstype[3]+=weight[i];
}
}
Double maxt=-1.0;int tag=1;for(int i=1;i<=3;i++){if(maxt
tag=i;
maxt=anstype[i];
}
}returntag;
}public static void main(String[] args) throwsIOException{
DataSet ds=newDataSet();
DataTest dt=newDataTest();int correct=0;for(int i=0;i
Data data=dt.dt.get(i);int result=knn(data,ds);if(result==data.type){
correct++;
}
}
System.out.println("total test num :"+dt.dt.size());
System.out.println("correct test num :"+correct);
System.out.println("ratio :"+correct/(double)dt.dt.size());
}
}