《机器学习实战》KNN 的 Java 实现

4 篇文章 0 订阅
1 篇文章 0 订阅
这个博客展示了如何使用Java实现KNN(K最近邻)机器学习算法,包括距离计算、排序、数据处理和核心分类算法。代码示例中包含了读取文件、归一化处理以及简单的测试案例,如约会测试和手写识别。
摘要由CSDN通过智能技术生成
package com.haolidong.kNN;

import java.util.Comparator;
/**
 * Orders {@link Distances} entries by their distance value in ascending order
 * (smallest distance first).
 *
 * @author haolidong
 */
public class ComparatorImpl implements Comparator<Distances>{

	/**
	 * Compares two entries by the value returned from {@code getDistances()}.
	 *
	 * <p>The comparison is delegated to {@link Double#compare(double, double)} so
	 * that the ordering is total: a NaN distance sorts after every other value
	 * instead of comparing as "equal" to everything, which would violate the
	 * {@link Comparator} contract and can make sorting fail or misorder.
	 *
	 * @param arg0 first entry
	 * @param arg1 second entry
	 * @return a negative value, zero or a positive value when the distance of
	 *         {@code arg0} is smaller than, equal to or larger than that of {@code arg1}
	 */
	@Override
	public int compare(Distances arg0, Distances arg1) {
		return Double.compare(arg0.getDistances(), arg1.getDistances());
	}

}

package com.haolidong.kNN;
/**
 * 
 * @author haolidong
 * @Description:  [该类主要用于保存KNN的距离信息以及index]  
 */
public class Distances {
	double distances;
	public Distances()
	{
		distances=0.0;
		sortedDistIndicies=0;
	}
	public Distances(double distances, int sortedDistIndicies) {
		super();
		this.distances = distances;
		this.sortedDistIndicies = sortedDistIndicies;
	}
	int sortedDistIndicies;
	public double getDistances() {
		return distances;
	}
	public void setDistances(double distances) {
		this.distances = distances;
	}
	public int getSortedDistIndicies() {
		return sortedDistIndicies;
	}
	public void setSortedDistIndicies(int sortedDistIndicies) {
		this.sortedDistIndicies = sortedDistIndicies;
	}	
}
package com.haolidong.kNN;

import java.util.ArrayList;

/**
 * Simple container holding a numeric matrix together with a list of string labels.
 *
 * @author haolidong
 */
public class ReturnML {
	/** Numeric rows; each inner list is one row of values. */
	public ArrayList<ArrayList<Double>> AR = new ArrayList<>();
	/** String labels kept alongside the rows. */
	public ArrayList<String> AS = new ArrayList<>();

	/** Creates a container whose matrix and label list are both empty. */
	public ReturnML() {
	}
}
package com.haolidong.kNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.Set;

public class KNN {

	// Shared data set: file2matrix appends rows/labels to it and autoNorm rescales its rows in place.
	public static ReturnML rml = new ReturnML();
	

	/**
	 * Runs the three demos in sequence: the small hand-built classification,
	 * the dating data-set test and the handwritten-digit test.
	 *
	 * @param args command-line arguments (not read)
	 * @throws IOException declared because handwritingClassTest() declares it
	 * @author haolidong
	 */
	public static void main(String[] args) throws IOException {

		testClassify();
		datingClassTest();
		handwritingClassTest();
	}
	
	/**
	 * Loads the dating data file into {@code rml}, normalises it and prints
	 * every row: the row number, each value, then the row's label.
	 *
	 * @author haolidong
	 */
	public static void test(){
		file2matrix("I:\\10yue1\\machinelearninginaction\\Ch02\\datingTestSet2.txt");
		autoNorm();
		int rowNo = 0;
		for (ArrayList<Double> row : rml.AR) {
			System.out.print(rowNo+":    ");
			for (Double value : row) {
				System.out.print(value+"          ");
			}
			// The label list is indexed by the same row number as the matrix.
			System.out.println(rml.AS.get(rowNo));
			rowNo++;
		}
	}
	/**
	 * Classifies the point (0.0, 0.0) against a fixed four-sample training set
	 * (two samples labelled "A", two labelled "B") with k = 3, prints the
	 * resulting label and returns it.
	 *
	 * @return the label returned by {@code classify}
	 * @author haolidong
	 */
	public static String testClassify(){
		// Fixed training fixture: one row of coordinates per sample, labels in the same order.
		double[][] points = { {1.0, 1.1}, {1.0, 1.0}, {0.0, 0.0}, {0.0, 0.1} };
		String[] tags = { "A", "A", "B", "B" };

		ArrayList<ArrayList<Double>> group = new ArrayList<ArrayList<Double>>();
		ArrayList<String> labels = new ArrayList<String>();
		for (int row = 0; row < points.length; row++) {
			ArrayList<Double> sample = new ArrayList<Double>();
			for (double coordinate : points[row]) {
				sample.add(coordinate);
			}
			group.add(sample);
			labels.add(tags[row]);
		}

		ArrayList<Double> input = new ArrayList<Double>();
		input.add(0.0);
		input.add(0.0);

		String lab = classify(input, group, labels, 3);
		System.out.println(lab);
		return lab;
	}

	/**
	 * Core k-nearest-neighbour classifier: computes the Euclidean distance from
	 * {@code inX} to every row of {@code dataSet}, takes the {@code k} closest
	 * rows and returns the label that occurs most often among them.
	 *
	 * <p>Neither {@code inX}, {@code dataSet} nor {@code labels} is modified.
	 * Rows at equal distance keep their original order. When {@code k} is
	 * larger than the number of rows, all rows vote. When several labels
	 * receive the same highest vote count, the one that comes first in the
	 * iteration order of the internal {@link HashMap} is returned.
	 *
	 * @param inX     feature vector to classify; must have at least as many entries as each training row
	 * @param dataSet training matrix, one row per sample
	 * @param labels  label of each training row, indexed like {@code dataSet}
	 * @param k       number of nearest neighbours that vote; must be at least 1
	 * @return        the winning label
	 * @throws IllegalArgumentException if {@code dataSet} is null or empty, or {@code k} is less than 1
	 * @author haolidong
	 */
	public static String classify(ArrayList<Double> inX,ArrayList<ArrayList<Double>> dataSet,ArrayList<String> labels,int k)
	{
		if (dataSet == null || dataSet.isEmpty()) {
			throw new IllegalArgumentException("dataSet must contain at least one sample");
		}
		if (k < 1) {
			throw new IllegalArgumentException("k must be at least 1, got " + k);
		}
		final int sampleCount = dataSet.size();
		// Asking for more neighbours than there are samples means "let everyone vote".
		final int neighbours = Math.min(k, sampleCount);

		// Distances are computed straight from the inputs; no copy of the training matrix is needed.
		final double[] dist = new double[sampleCount];
		ArrayList<Integer> order = new ArrayList<Integer>(sampleCount);
		for (int i = 0; i < sampleCount; i++) {
			ArrayList<Double> row = dataSet.get(i);
			double sum = 0.0;
			for (int j = 0; j < row.size(); j++) {
				double diff = inX.get(j) - row.get(j);
				sum += diff * diff;
			}
			dist[i] = Math.sqrt(sum);
			order.add(i);
		}

		// Collections.sort is stable, so rows with equal distance stay in index order.
		// Double.compare gives a total order even if a distance is NaN.
		Collections.sort(order, new Comparator<Integer>() {
			@Override
			public int compare(Integer left, Integer right) {
				return Double.compare(dist[left], dist[right]);
			}
		});

		HashMap<String,Integer> classCount = new HashMap<String,Integer>();
		for (int i = 0; i < neighbours; i++) {
			String voteIlabel = labels.get(order.get(i));
			Integer seen = classCount.get(voteIlabel);
			classCount.put(voteIlabel, seen == null ? 1 : seen + 1);
		}

		// Pick the first entry (in map iteration order) holding the highest count;
		// the strict '>' keeps the earliest entry on ties.
		String best = null;
		int bestVotes = 0;
		for (Entry<String, Integer> vote : classCount.entrySet()) {
			if (vote.getValue() > bestVotes) {
				bestVotes = vote.getValue();
				best = vote.getKey();
			}
		}
		return best;
	}
	
	/**
	 * Returns a new map with the same entries as {@code oldMap}, iterating in
	 * descending order of value. Entries with equal values keep the relative
	 * order in which {@code oldMap} iterates them. {@code oldMap} is not modified.
	 *
	 * @param oldMap map to sort; values must be non-null
	 * @return a {@link LinkedHashMap} holding the entries in descending value order
	 */
	public static HashMap<String,Integer> sortMap(HashMap<String,Integer> oldMap) {
		ArrayList<Entry<String, Integer>> list = new ArrayList<Entry<String, Integer>>(oldMap.entrySet());
		Collections.sort(list, new Comparator<Entry<String, Integer>>() {

			@Override
			public int compare(Entry<String, Integer> arg0, Entry<String, Integer> arg1) {
				// Integer.compare rather than subtraction: a difference can overflow
				// and flip the sign for values far apart.
				return Integer.compare(arg1.getValue(), arg0.getValue());
			}
		});
		// LinkedHashMap preserves the sorted insertion order.
		HashMap<String, Integer> newMap = new LinkedHashMap<String, Integer>();
		for (Entry<String, Integer> entry : list) {
			newMap.put(entry.getKey(), entry.getValue());
		}
		return newMap;
	}
	
	
	/**
	 * Reads a tab-separated text file and appends its contents to {@code rml}:
	 * for every line, all fields except the last are parsed as doubles and
	 * added as one row of {@code rml.AR}; the last field is added to
	 * {@code rml.AS} as that row's label.
	 *
	 * <p>Rows are appended, so calling this method more than once accumulates
	 * data. Blank lines are skipped. An {@link IOException} is reported with
	 * {@code printStackTrace()} and whatever was read before it stays in {@code rml}.
	 *
	 * @param fileName path of the file to read
	 * @throws NumberFormatException if a non-label field is not a valid double
	 * @author haolidong
	 */
	public static void file2matrix(String fileName){
		// try-with-resources closes the reader exactly once, on every path.
		try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
			String line;
			while ((line = reader.readLine()) != null) {
				// A blank line would otherwise produce an empty row (or no fields at all).
				if (line.trim().isEmpty()) {
					continue;
				}
				String[] strArr = line.split("\t");
				ArrayList<Double> ad = new ArrayList<Double>(strArr.length);
				for (int i = 0; i < strArr.length-1; i++) {
					ad.add(Double.parseDouble(strArr[i]));
				}
				rml.AR.add(ad);
				rml.AS.add(strArr[strArr.length-1]);
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * Min-max normalises {@code rml.AR} in place, column by column:
	 * {@code x -> (x - min) / (max - min)}.
	 *
	 * <p>Does nothing when the matrix is empty. A column whose values are all
	 * equal (range 0) is mapped to {@code 0.0} instead of dividing by zero.
	 * The column count is taken from the first row.
	 *
	 * @author haolidong
	 */
	public static void autoNorm(){
		if (rml.AR.isEmpty()) {
			return;
		}
		ArrayList<Double> first = rml.AR.get(0);
		int columns = first.size();
		double[] min = new double[columns];
		double[] max = new double[columns];
		// Seed the per-column extremes with the first row.
		for (int j = 0; j < columns; j++) {
			min[j] = first.get(j);
			max[j] = first.get(j);
		}
		for (ArrayList<Double> row : rml.AR) {
			for (int j = 0; j < row.size(); j++) {
				double value = row.get(j);
				if (value > max[j]) {
					max[j] = value;
				}
				if (value < min[j]) {
					min[j] = value;
				}
			}
		}
		for (ArrayList<Double> row : rml.AR) {
			for (int j = 0; j < row.size(); j++) {
				double range = max[j] - min[j];
				// Constant column: every value equals min, so 0.0 is the consistent result.
				row.set(j, range == 0.0 ? 0.0 : (row.get(j) - min[j]) / range);
			}
		}
	}
	/**
	 * Hold-out test on the dating data set: loads and normalises the file,
	 * uses the first half of the rows ({@code hoRatio = 0.50}) as test samples
	 * and the remaining rows as training data, classifies every test sample
	 * with k = 3 and prints each prediction plus the overall error rate.
	 *
	 * <p>Labels are parsed with {@link Integer#parseInt(String)} after trimming.
	 *
	 * @author haolidong
	 */
	public static void datingClassTest(){
		double hoRatio = 0.50;
		file2matrix("I:\\10yue1\\machinelearninginaction\\Ch02\\datingTestSet2.txt");
		// Normalising once is enough: a second pass over already-normalised data changes nothing.
		autoNorm();
		int m = rml.AR.size();
		int numTestVecs = (int) (m*hoRatio);
		if (numTestVecs == 0) {
			// Avoids printing a 0/0 error rate when nothing could be loaded.
			System.out.println("no test samples available; nothing to classify");
			return;
		}

		// Training data = rows [numTestVecs, m). The rows themselves are shared, not copied.
		ArrayList<ArrayList<Double>> group = new ArrayList<ArrayList<Double>>(rml.AR.subList(numTestVecs, m));
		ArrayList<String> labels = new ArrayList<String>(rml.AS.subList(numTestVecs, m));

		int errorCount1 = 0;
		for (int i = 0; i < numTestVecs; i++) {
			// Trim both sides so stray whitespace in the file cannot break parsing of either label.
			int s1 = Integer.parseInt(classify(rml.AR.get(i),group,labels,3).trim());
			int s2 = Integer.parseInt(rml.AS.get(i).trim());
			System.out.println("the classifier came back with: "+s1+"  the real answer is: "+s2);
			if (s1 != s2) {
				errorCount1++;
			}
		}
		System.out.println("the total error rate is:"+1.0*errorCount1/numTestVecs);
	}
	
	/**
	 * Reads a text image file and flattens it into one vector: every character
	 * of every line must be a decimal digit and is appended as a double, line
	 * after line.
	 *
	 * <p>An {@link IOException} is reported with {@code printStackTrace()} and
	 * the values read so far are returned.
	 *
	 * @param file text file containing only digit characters and line breaks
	 * @return the digits of the file in reading order
	 * @throws NumberFormatException if a line contains a character that is not a decimal digit
	 * @author haolidong
	 */
	public static ArrayList<Double> img2vector(File file){
		ArrayList<Double> ad = new ArrayList<Double>();
		// try-with-resources closes the reader exactly once, on every path.
		try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
			String line;
			while ((line = reader.readLine()) != null) {
				for (int i = 0; i < line.length(); i++) {
					char c = line.charAt(i);
					// Character.digit avoids building a one-character String per pixel.
					int digit = Character.digit(c, 10);
					if (digit < 0) {
						throw new NumberFormatException("For input string: \"" + c + "\"");
					}
					ad.add((double) digit);
				}
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
		return ad;
	}
	
	/**
	 * Handwritten-digit test: loads every image file from the training and
	 * test directories, classifies each test image against the training set
	 * with k = 3 and prints each prediction, the number of errors and the
	 * error rate.
	 *
	 * <p>The expected label of an image is the part of its file name before the
	 * first underscore. The error rate is computed over the images that were
	 * actually loaded, so sub-directories or other skipped entries in the test
	 * directory do not distort it.
	 *
	 * @throws IOException if the training or test directory cannot be listed
	 * @author haolidong
	 */
	public static void handwritingClassTest() throws IOException
	{
		String pathTest="I:\\machinelearninginaction\\Ch02\\testDigits\\";
		String pathTrain="I:\\machinelearninginaction\\Ch02\\trainingDigits\\";

		ArrayList<ArrayList<Double>> trainingMat = new ArrayList<ArrayList<Double>>();
		ArrayList<String> trainingLabel = new ArrayList<String>();
		loadDigitDirectory(new File(pathTrain), trainingMat, trainingLabel);

		ArrayList<ArrayList<Double>> vectorUnderTest = new ArrayList<ArrayList<Double>>();
		ArrayList<String> testLabel = new ArrayList<String>();
		loadDigitDirectory(new File(pathTest), vectorUnderTest, testLabel);

		// Iterate over what was loaded, not over the raw directory listing.
		int testCount = vectorUnderTest.size();
		int errorCount = 0;
		for (int i = 0; i < testCount; i++) {
			String classifyLabel = classify(vectorUnderTest.get(i), trainingMat, trainingLabel, 3);
			System.out.println("the classifier came back with:"+classifyLabel+", the real answer is:"+testLabel.get(i));
			if (!classifyLabel.equals(testLabel.get(i))) {
				errorCount++;
			}
		}
		System.out.println("the total number of errors is:"+errorCount);
		System.out.println("the total error rate is: "+1.0*errorCount/testCount);
	}

	/**
	 * Appends the vector and label of every regular file in {@code dir} whose
	 * name contains an underscore; other entries are skipped.
	 */
	private static void loadDigitDirectory(File dir, ArrayList<ArrayList<Double>> vectors, ArrayList<String> labels) throws IOException {
		File[] entries = dir.listFiles();
		if (entries == null) {
			// listFiles() returns null when the path is not a directory or cannot be read.
			throw new IOException("Not a readable directory: " + dir);
		}
		for (File entry : entries) {
			if (!entry.isFile()) {
				continue;
			}
			// Use the bare file name so the label does not depend on the path
			// separator or on underscores in parent directory names.
			String name = entry.getName();
			int underscore = name.indexOf('_');
			if (underscore < 0) {
				continue;
			}
			vectors.add(img2vector(entry));
			labels.add(name.substring(0, underscore));
		}
	}
}



评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值