import java.io.*;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;

/**
 * Command-line entry point: loads the training files found in the
 * {@code trainingDigits} directory, loads one test sample and prints the
 * absolute paths of its nearest training files.
 */
public class TestKnn {
    public static void main(String[] args) throws IOException {
        Knn knn = new Knn("trainingDigits");
        knn.addTest(new File("testDigits/2_1.txt"));
        knn.classifierKnn();
    }
}

/**
 * Nearest-neighbour search over text files of decimal digits.
 *
 * <p>Each file is flattened into an {@code int} vector (one element per
 * character, rows laid out one after another) and compared with the test
 * vector using {@link #distance(int[], int[])}. The class only reports the
 * nearest training files; it does not derive a label from them.
 */
class Knn {
    /** Number of neighbours kept by the no-argument {@link #classifierKnn()}. */
    static final int DEFAULT_K = 5;

    /** Capacity of the flattened sample vectors (same size the buffers always had). */
    private static final int VECTOR_LENGTH = 32 * 33;

    /** Scratch vector holding the training sample currently being compared. */
    int[] train;
    /** Vector of the sample most recently loaded through {@link #addTest(File)}. */
    int[] test;
    /** Neighbours selected by the most recent classification run, in insertion order. */
    LinkedList<MyFile> result = new LinkedList<>();
    /** Training files collected from the directory given to the constructor. */
    LinkedList<File> trains = new LinkedList<>();

    /**
     * Allocates the sample buffers and collects the training files.
     *
     * @param trainPath directory whose regular files are used as training samples
     */
    Knn(String trainPath) {
        train = new int[VECTOR_LENGTH];
        test = new int[VECTOR_LENGTH];
        traverseFolder1(trainPath, trains);
    }

    /**
     * Removes every entry of {@link #result} whose distance equals the given value.
     *
     * <p>Kept for callers of the original API; {@link #classifierKnn(int)} no
     * longer uses it because removing all tied entries can shrink the
     * neighbour list below k.
     *
     * @param index the distance value to remove (compared with {@code ==})
     */
    public void delete(double index) {
        Iterator<MyFile> entries = result.iterator();
        while (entries.hasNext()) {
            if (entries.next().distance == index) {
                entries.remove();
            }
        }
    }

    /**
     * Returns the largest distance currently stored in {@link #result}.
     *
     * @return the maximum distance, or {@code 0.0} when {@link #result} is empty
     */
    public double findMan() {
        double max = 0.0;
        for (MyFile entry : result) {
            if (entry.distance > max) {
                max = entry.distance;
            }
        }
        return max;
    }

    /**
     * Runs {@link #classifierKnn(int)} with {@link #DEFAULT_K} neighbours.
     *
     * @throws IOException if a training file cannot be read
     */
    public void classifierKnn() throws IOException {
        classifierKnn(DEFAULT_K);
    }

    /**
     * Finds the {@code k} training files closest to the current test vector,
     * stores them in {@link #result} and prints their absolute paths, one per
     * line, in the order they were accepted.
     *
     * <p>On equal distances the entry accepted earlier is kept.
     *
     * @param k number of neighbours to keep; must be positive
     * @throws IOException if a training file cannot be read
     * @throws IllegalArgumentException if {@code k} is not positive or a
     *         training file is malformed (see {@link #addTest(File)})
     */
    public void classifierKnn(int k) throws IOException {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        // Start from a clean slate so repeated runs do not pile up neighbours.
        result.clear();

        for (File candidate : trains) {
            loadDigits(candidate, train);
            double d = distance(test, train);

            if (result.size() < k) {
                result.add(new MyFile(d, candidate));
                continue;
            }
            // The list is full: the newcomer must beat the current worst
            // neighbour, and exactly one entry is evicted to make room.
            MyFile farthest = farthestNeighbour();
            if (d < farthest.distance) {
                result.remove(farthest);
                result.add(new MyFile(d, candidate));
            }
        }

        for (MyFile neighbour : result) {
            System.out.println(neighbour.file.getAbsolutePath());
        }
    }

    /**
     * Computes the square root of the number of positions at which the two
     * vectors differ. For vectors containing only 0 and 1 this equals the
     * Euclidean distance.
     *
     * @param a first vector; its length determines how many positions are compared
     * @param b second vector; must be at least as long as {@code a}
     * @return the square root of the mismatch count
     */
    public double distance(int[] a, int[] b) {
        int mismatches = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[i]) {
                mismatches++;
            }
        }
        return Math.sqrt(mismatches);
    }

    /**
     * Loads the sample to classify into {@link #test}, replacing any
     * previously loaded sample.
     *
     * @param file text file whose characters are all decimal digits
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if the file holds a non-digit character
     *         ({@link NumberFormatException}) or more data than the buffer can hold
     */
    public void addTest(File file) throws IOException {
        loadDigits(file, test);
    }

    /**
     * Appends every regular file directly inside {@code path} to {@code list},
     * sorted by pathname so that runs are reproducible. Sub-directories are
     * not descended into.
     *
     * @param path directory to scan
     * @param list receives the files that were found
     */
    public void traverseFolder1(String path, LinkedList<File> list) {
        // Marker output whose purpose is not evident from this file; retained
        // so that existing console output does not change.
        System.out.print("a");
        File file = new File(path);
        if (!file.exists()) {
            System.out.println("文件不存在!");
            return;
        }
        File[] files = file.listFiles();
        if (files == null) {
            // listFiles() yields null for non-directories and on I/O errors.
            System.err.println("Not a readable directory: " + path);
            return;
        }
        // listFiles() makes no ordering promise; sort for deterministic results.
        Arrays.sort(files);
        for (File entry : files) {
            if (entry.isFile()) {
                list.add(entry);
            }
        }
    }

    /** Returns the entry of {@link #result} with the largest distance (first one on ties). */
    private MyFile farthestNeighbour() {
        MyFile farthest = null;
        for (MyFile entry : result) {
            if (farthest == null || entry.distance > farthest.distance) {
                farthest = entry;
            }
        }
        return farthest;
    }

    /** Zeroes {@code target} and fills it with the digits of {@code source}, row after row. */
    private static void loadDigits(File source, int[] target) throws IOException {
        Arrays.fill(target, 0);
        try (BufferedReader reader = new BufferedReader(new FileReader(source))) {
            String line;
            int row = 0;
            while ((line = reader.readLine()) != null) {
                int width = line.length();
                for (int col = 0; col < width; col++) {
                    // Row stride is the length of the current line, as before.
                    int cell = row * width + col;
                    if (cell >= target.length) {
                        throw new IllegalArgumentException(
                                "Sample does not fit into " + target.length + " cells: " + source);
                    }
                    int digit = Character.digit(line.charAt(col), 10);
                    if (digit < 0) {
                        throw new NumberFormatException(
                                "Non-digit character '" + line.charAt(col) + "' in " + source);
                    }
                    target[cell] = digit;
                }
                row++;
            }
        }
    }
}

/** Pairs a training file with its distance to the test sample. */
class MyFile {
    double distance;
    File file;

    /**
     * @param f  distance between the file's vector and the test vector
     * @param fi the training file
     */
    MyFile(double f, File fi) {
        distance = f;
        file = fi;
    }
}
模式识别——KNN算法(Java实现)
最新推荐文章于 2023-10-11 17:02:39 发布