K最近邻(KNN)算法原理和java实现

1.KNN算法原理:

(1)基于类比原理,通过比较训练元组和测试元组的相似度来学习

        (2)将训练元组和测试元组看作是n维空间内的点,给定一条测试元组,搜索n维空间,找出与测试元组最相近的k个点(即训练元组),最后取这k个点中的多数类做作为测试元组 的类别。

(3)相近的度量方法:用空间内两个点的距离来度量,距离越大,表示两个点越不相似。

(4)距离的选择:可采用欧几里得距离,曼哈顿距离,等其他度量方法,一般采用欧几里得距离,比较简单。

2.KNN算法中的细节处理

(1)数值属性规范化:将数值属性规范到0-1区间以便于计算,也可防止大数值型属性对分类的主导。

(2)可选的方法有:v‘=(v-vmin)/(vmax-vmin),也有其他方法。

(3)比较的数据是分类类型而不是数值类型,同则差为0,异则差为1,有时候可以做更精确的处理,比如黑色和白色的差肯定大于灰色和白色的差。

(4)缺失值的处理(不理解,以后理解):取最大的可能差,对于分类属性,如果属性A的一个或两个对应值丢失,则取差值为1;如果A是数值属性,若两个比较的元组A属性值均缺失,则取差值为1,若只有一个缺失,另一个值为v,则取差值为|1-v|和|0-v|中的最大值。

(5)确定K的值:通过实验确定。进行若干次实验,取分类误差率最小的k值。

(7)对噪声数据或不相关属性的处理:对属性赋予相关性权重w,w越大说明属性对分类的影响越相关。对噪声数据可以将所在的元组直接cut掉。

3.KNN算法流程

 1)准备数据,对数据进行预处理

 2)选用合适的数据结构存储训练数据和测试元组 

         3)设定参数,如k

 4)维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元        组的距离,将训练元组标号和距离存入优先级队列 
 5)遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与优先级队列中的最大距离Lmax进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L       < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。 

         6)遍历完毕,计算优先级队列中k个元组的多数类,并将其作为测试元组的类别。

         7)测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k值。

4.Java实现

KNN结点类,用来存储最近邻的k个元组相关的信息

/** 
 * KNN结点类,用来存储最近邻的k个元组相关的信息 
 */  
public class KNNNode {  
    private int index;          // 元组标号  
    private double distance;    // 与测试元组的距离  
    private String c;           // 所属类别  
    public KNNNode(int index, double distance, String c) {  
        super();  
        this.index = index;  
        this.distance = distance;  
        this.c = c;  
    }  
      
      
    public int getIndex() {  
        return index;  
    }  
    public void setIndex(int index) {  
        this.index = index;  
    }  
    public double getDistance() {  
        return distance;  
    }  
    public void setDistance(double distance) {  
        this.distance = distance;  
    }  
    public String getC() {  
        return c;  
    }  
    public void setC(String c) {  
        this.c = c;  
    }  
}  
KNN算法主体类

/** 
 * KNN算法主体类 
 */  
public class KNN {  
    /** 
     * 设置优先级队列的比较函数,距离越大,优先级越高 
     */  
    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
        public int compare(KNNNode o1, KNNNode o2) {  
            if (o1.getDistance() >= o2.getDistance()) {  
                return 1;  
            } else {  
                return 0;  
            }  
        }  
    };  
    /** 
     * 获取K个不同的随机数 
     * @param k 随机数的个数 
     * @param max 随机数最大的范围 
     * @return 生成的随机数数组 
     */  
    public List<Integer> getRandKNum(int k, int max) {  
        List<Integer> rand = new ArrayList<Integer>(k);  
        for (int i = 0; i < k; i++) {  
            int temp = (int) (Math.random() * max);  
            if (!rand.contains(temp)) {  
                rand.add(temp);  
            } else {  
                i--;  
            }  
        }  
        return rand;  
    }  
    /** 
     * 计算测试元组与训练元组之前的距离 
     * @param d1 测试元组 
     * @param d2 训练元组 
     * @return 距离值 
     */  
    public double calDistance(List<Double> d1, List<Double> d2) {  
        double distance = 0.00;  
        for (int i = 0; i < d1.size(); i++) {  
            distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
        }  
        return distance;  
    }  
    /** 
     * 执行KNN算法,获取测试元组的类别 
     * @param datas 训练数据集 
     * @param testData 测试元组 
     * @param k 设定的K值 
     * @return 测试元组的类别 
     */  
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
        List<Integer> randNum = getRandKNum(k, datas.size());  
        for (int i = 0; i < k; i++) {  
            int index = randNum.get(i);  
            List<Double> currData = datas.get(index);  
            String c = currData.get(currData.size() - 1).toString();  
            KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
            pq.add(node);  
        }  
        for (int i = 0; i < datas.size(); i++) {  
            List<Double> t = datas.get(i);  
            double distance = calDistance(testData, t);  
            KNNNode top = pq.peek();  
            if (top.getDistance() > distance) {  
                pq.remove();  
                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
            }  
        }  
          
        return getMostClass(pq);  
    }  
    /** 
     * 获取所得到的k个最近邻元组的多数类 
     * @param pq 存储k个最近近邻元组的优先级队列 
     * @return 多数类的名称 
     */  
    private String getMostClass(PriorityQueue<KNNNode> pq) {  
        Map<String, Integer> classCount = new HashMap<String, Integer>();  
        for (int i = 0; i < pq.size(); i++) {  
            KNNNode node = pq.remove();  
            String c = node.getC();  
            if (classCount.containsKey(c)) {  
                classCount.put(c, classCount.get(c) + 1);  
            } else {  
                classCount.put(c, 1);  
            }  
        }  
        int maxIndex = -1;  
        int maxCount = 0;  
        Object[] classes = classCount.keySet().toArray();  
        for (int i = 0; i < classes.length; i++) {  
            if (classCount.get(classes[i]) > maxCount) {  
                maxIndex = i;  
                maxCount = classCount.get(classes[i]);  
            }  
        }  
        return classes[maxIndex].toString();  
    }  
}  
KNN算法测试类

package feature;

import java.io.BufferedReader;
import java.io.*;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

// KNN算法测试类
public class TestKNN {
	/**
	 * 从数据文件中读取数据
	 * 
	 * @param datas
	 *            存储数据的集合对象
	 * @param path
	 *            数据文件的路径
	 */
	public void read(List<List<Double>> datas, String path) {
		try {
			BufferedReader br = new BufferedReader(new FileReader(
					new File(path)));
			String reader = br.readLine();
			while (reader != null) {
				String t[] = reader.split("\t");
				ArrayList<Double> list = new ArrayList<Double>();
				for (int i = 0; i < t.length; i++) {
					list.add(Double.parseDouble(t[i]));
				}
				datas.add(list);
				reader = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * 程序执行入口
	 * 
	 * @param args
	 */
	public static void main(String[] args) {
		TestKNN t = new TestKNN();
		String datafile="e:\\work\\交接数据\\testData\\KNN\\datafile.txt";
		String testfile="e:\\work\\交接数据\\testData\\KNN\\testfile.txt";
		try {
			List<List<Double>> datas = new ArrayList<List<Double>>();
			List<List<Double>> testDatas = new ArrayList<List<Double>>();
			t.read(datas, datafile);
			t.read(testDatas, testfile);
			KNN knn = new KNN();
			FileWriter fw=new FileWriter("e:\\work\\交接数据\\testData\\KNN\\result.txt");
			for (int i = 0; i < testDatas.size(); i++) {
				List<Double> test = testDatas.get(i);
				System.out.print("测试元组: ");
				for (int j = 0; j < test.size(); j++) {
					
					fw.write(test.get(j) + "\t");
					fw.flush();
					System.out.print(test.get(j) + "\t");
				}
				System.out.print("类别为: ");
				fw.write(Math.round(Float.parseFloat((knn.knn(datas,
						test, 3))))+"\n");
				fw.flush();
				System.out.println(Math.round(Float.parseFloat((knn.knn(datas,
						test, 3)))));
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
}
训练数据:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

测试数据:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

结果是:

测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0

5.错误及解决办法

在测试KNN算法时出现错误

java.lang.NumberFormatException: For input string: "1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0"
		at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1224)
		at java.lang.Double.parseDouble(Double.java:510)
		at feature.TestKNN.read(TestKNN.java:29)
		at feature.TestKNN.main(TestKNN.java:53)
	java.lang.NumberFormatException: For input string: "1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5"
		at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1224)
		at java.lang.Double.parseDouble(Double.java:510)
		at feature.TestKNN.read(TestKNN.java:29)
		at feature.TestKNN.main(TestKNN.java:54)
错误原因是发现read方法中出错了
public void read(List<List<Double>> datas, String path) {
		try {
			BufferedReader br = new BufferedReader(new FileReader(
					new File(path)));
			String reader = br.readLine();
			while (reader != null) {
				String t[] = reader.split("\t");
				ArrayList<Double> list = new ArrayList<Double>();
				for (int i = 0; i < t.length; i++) {
					list.add(Double.parseDouble(t[i]));
				}
				datas.add(list);
				reader = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
}
出错行是list.add(Double.parseDouble(t[i]));
经分析发现输入数据每个数据之间的间隔是空格,不是tab,而程序行String t[] = reader.split("\t");说明数据
之间需要用tab键隔开。解决办法有两种:
第一:把程序行String t[] = reader.split("\t");改成String t[] = reader.split(" ");
第二:将输入数据之间的空格键改成tab键。


取最大的可能差,对于分类属性,如果属性

A

的一个或两个对应值丢失,则取差值为

1

如果

A

是数值属性,若两个比较的元组

A

属性值均缺失,则取差值为

1

,若只有一个缺失,另一个值为

v

则取差值为|

1-v

|和|

0-v

|中的最大值

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值