KNN 近邻算法

- 指导思想

先计算待分类样本与已知类别的训练样本之间的距离,找到距离与待分类样本数据最近的k个邻居;再根据这些邻居所属类别来判断待分类样本数据的类别。

- 算法步骤

  1. 计算距离:给定测试对象,计算它与训练对象集中每个对象的距离。
  2. 找邻居:选自距离最近的k个训练对象,作为测试对象的近邻。
  3. 做分类:根据K个近邻的主要类别,来对测试对象进行分类。

优缺点

优点

  1. 简单、易于理解以及实现
  2. 适合对稀有事件进行分类
  3. 特别适合与多分类问题(对象具有多个类别标签)

缺点

  1. 对测试样本分类时的计算量大、内存开销大,评分慢。
  2. 可解释性差,没有决策树强
  3. 特别适合与多分类问题(对象具有多个类别标签)

算法流程

  • 第一步:

    创建大小为K的,按照测试元组与既存元组之间的距离由大到小的优先级队列,用于存储最近邻训练数据。
    随机在既存元组中选取k个元组,作为最初始的最近邻元组,分别计算测试元组到这K个既存元组之间的距离,并将既存元组的标号和距离追加到优先级队列。

  • 第二步:

    遍历既存元组集,计算当前的既存元组与测试元组之间的距离,将所得的距离(L)与priorityQueue中的最大距离(Lmax)进行比较。
    若L > Lmax, 则舍弃该元组遍历下一个元组。若L < Lmax,则删除priorityQueue中距离最大的元组,将当前既存元组存入priorityQueue中。

  • 第三步:

    遍历完毕以后,计算优先级队列中k个最近邻元组的多数类,并将其作为测试元组的类别

算法改进

  • 通过降维技术来减少维数,如主成分分析,因子分析,变量选择(因子选择)从而减少计距离的时间。
  • 用复杂的数据结构,如搜索树去加速最近邻的确定。
  • 编辑训练数据去减少在训练集中的冗余和几乎是冗余的点,从而加速搜索最近邻。

JavaDemo的实现

创建KNN算法的是实现类

package knn;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * 
 * @author FeiYue
 * KNN 算法 的实现
 */
public class KNN {
    /**
     * 设置优先级比较函数,距离越大,优先级越高
     */

    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {

        @Override
        public int compare(KNNNode knnode1, KNNNode knnode2) {
            if (knnode1.getDistance() >= knnode2.getDistance()) {
                return -1;
            } 
            return 0;
        }

    };

    /**
     * 获取K个不同的随机数
     * @param k 随机数的范围
     * @param max  随机数的最大范围
     * @return 生成随机数的数组
     */
    public List<Integer> getRandKNum(int k, int max) {
        List<Integer> rand = new ArrayList<>();
        System.out.println();
        System.out.print(k + "个随机数:");
        for (int i = 0; i < k; i++) {
            int temp = (int) (Math.random() * max);

            /**
             * 随机数集合中如果包含temp,则不追加该数,重新做成。
             */
            if (!rand.contains(temp)) {
                System.out.print(temp + " ");
                rand.add(temp);
            } else {
                i--;
            }
        }
        System.out.println();
        return rand;    
    }

    /**
     * 计算测试元组与既存元组之间的距离
     * 本实例采用的距离计算公式是欧式距离计算
     * 0ρ = sqrt( (x1-x2)^2+(y1-y2)^2 )   |x| = √( x2 + y2 )
     * @param d1 测试元组
     * @param d2 既存元组
     * @return 距离值
     */
    public double calculateDistance (List<Double> d1, List<Double> d2) {
        double distance = 0.00;
        System.out.print("测试数据与既存数据之间的距离 :");

        for (int i = 0; i < d1.size(); i++) {
            double temp = Math.pow((d1.get(i) - d2.get(i)), 2);         
            distance += temp;
        }

        distance = Math.sqrt(distance);
        System.out.println(distance);

        return distance;

    }

    /**
     * 执行KNN算法,获取测试元组的类别
     * @param datas  既存数据
     * @param testData  测试数据
     * @param k 设定的K值
     * @return 测试数据的类别
     */

    public String knn(List<List<Double>> datas, List<Double> testData, int k) {

        /**
         *第一步:
         *创建大小为K的,按照测试元组与既存元组之间的距离由大到小的优先级队列,用于存储最近邻训练数据。
         *随机在既存元组中选取k个元组,作为最初始的最近邻元组,分别计算测试元组到这K个既存元组之间的距离,
         *并将既存元组的标号和距离追加到优先级队列。
         */
        System.out.println("\n——————第一步——————");

        //创建大小为K的,按照测试元组与既存元组之间的距离由大到小的优先级队列,用于存储最近邻训练数据
        PriorityQueue<KNNNode> priorityQueue = new PriorityQueue<>(k, comparator);

        //随机在既存元组中选取k个元组,作为最初始的最近邻元组
        List<Integer> randNum = getRandKNum(k, datas.size());

        //分别计算测试元组到这K个既存元组之间的距离,并将其追加到priorityQueue
        for (int i = 0; i < k; i++) {
            int index = randNum.get(i);
            List<Double> currentData = datas.get(index);        
            String category = currentData.get(currentData.size() - 1) .toString();
            System.out.println("第" + index + "组既存数据的类别为:" + category);

            //创建KNNode,并且将该节点追加到priorityQueue队列中
            // KNNnode节点保存了测试元组到既存元组之间的距离以及既存元组的编号和类别           
            KNNNode node = new KNNNode(index, calculateDistance(testData, currentData), category);
            priorityQueue.add(node);
        }

        /**
         * 第二步:
         * 遍历既存元组集,计算当前的既存元组与测试元组之间的距离,
         * 将所得的距离(L)与priorityQueue中的最大距离(Lmax)进行比较。
         * 若L > Lmax, 则舍弃该元组遍历下一个元组。
         * 若L < Lmax,则删除priorityQueue中距离最大的元组,将当前既存元组存入priorityQueue中。
         */
        System.out.println("\n——————第二步——————");
        System.out.println("------既存元组与测试元组之间的距离 start--------");
        for (int i = 0; i < datas.size(); i++) {

            //获取当前的既存元组
            List<Double> currentDatas =datas.get(i);

            //计算当前的既存元组与测试元组之间的距离           
            double distance = calculateDistance(testData, currentDatas);

            //获取priorityQueue队头元组
            KNNNode top = priorityQueue.peek();

            //L < Lmax
            if (distance < top.getDistance() ) {
                //删除该元组
                priorityQueue.remove();

                //将当前既存元组存入priorityQueue中
                priorityQueue.add(new KNNNode(i, distance, currentDatas.get(currentDatas.size()-1).toString()));
            }
        }
        System.out.println("------既存元组与测试元组之间的距离 end--------"); 
        return getMostClass(priorityQueue);

    }

    /**
     * 获取得到的k个最近邻元组的多数类
     * @param priorityQueue 存储k个最近邻元组的优先级队列
     * @return 多数类的类别
     */

    private String getMostClass(PriorityQueue<KNNNode> priorityQueue) {
        System.out.println("\n——————第三步——————");
        //创建Map对象,用于计算各个类别出现的个数。
        Map<String, Integer> classCount = new HashMap<String, Integer>();
        int priorityQueueLength = priorityQueue.size();

        //统计优先级队列中,各个类别的个数,并将其追加到classCount中。
        for (int i = 0; i < priorityQueueLength; i++) {
            //将priorityQueue中的队头元素赋值给node,并且从优先级队列中删除移除该元素
            KNNNode node = priorityQueue.remove();

            //获取类别
            String category = node.getCategory();
            System.out.println(node.toString());
            //如果classCount集合已经存在该类别,则该类别数量加1。否则追加该类别,并且数量赋值为1.
            if (classCount.containsKey(category)) {
                classCount.put(category,classCount.get(category) + 1);
            } else {
                classCount.put(category, 1);
            }
        }

        int maxIndex = -1;
        int maxCount = 0;

        //统计在K个元组中,多数类的类别
        Object[] classes = classCount.keySet().toArray();
        for (int i = 0; i < classes.length; i++) {
            if (classCount.get(classes[i]) > maxCount) {
                maxIndex = i;
                maxCount = classCount.get(classes[i]);
            }
            System.out.println("类别" + classes[i] + "出现的个数为:" + classCount.get(classes[i]));
        }
        //返回多数类的类别
        return classes[maxIndex].toString();

    }



}

创建节点类:

package knn;

/**
 * 
 * @author FeiYue
 * KNN节点,用于保存既存元组编号,该元组与测试元组之间的距离以及类别该既存元组的类别
 */
public class KNNNode {
    private  int index; //元组标号
    private double distance; //与测试数据的距离
    private String category; //类别



    public KNNNode(int index, double distance, String category) {

        this.index = index;
        this.distance = distance;
        this.category = category;
    }


    public KNNNode() {

    }


    public int getIndex() {
        return index;
    }
    public void setIndex(int index) {
        this.index = index;
    }
    public double getDistance() {
        return distance;
    }
    public void setDistance(double distance) {
        this.distance = distance;
    }
    public String getCategory() {
        return category;
    }
    public void setCategory(String category) {
        this.category = category;
    }


    @Override
    public String toString() {
        String content = "既存元组编号:" + index + ",与测试数据的距离:" + distance + ",既存元组类别: " + category;
        return content;
    }
}

运行:

package knn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

/**
 * 
 * @author FeiYue
 * 运行程序
 */

public class Test {
    public void read(List<List<Double>> datas, String path, boolean flag) {
        try {
            @SuppressWarnings("resource")
            BufferedReader br = new BufferedReader(new FileReader(new File(path)));
            String reader = br.readLine();

            while (reader != null) {
                if (flag) {
                    System.out.println(reader);
                }           
                String t[] = reader.split(" ");
                ArrayList<Double> list = new ArrayList<>();
                for (int i = 0; i < t.length; i++) {
                    list.add(Double.parseDouble(t[i]));
                }
                datas.add(list);
                reader = br.readLine();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    } 

    public static void main(String[] args) {
        Test t = new Test();
        String USER_DIR = System.getProperty("user.dir");
        //既存数据
        String dataFile = USER_DIR + "\\datafile.txt";
        //测试数据
        String testFile = USER_DIR + "\\testfile.txt";

        try {
            List<List<Double>> datas = new ArrayList<List<Double>>();
            List<List<Double>> testDatas = new ArrayList<List<Double>>();
            System.out.println("既存数据:        ");
            t.read(datas, dataFile, true);
            t.read(testDatas, testFile, false);
            System.out.println("\n");
            KNN knn =new KNN();
            for (int i = 0; i < testDatas.size(); i++) {
                System.out.println("********************第" + (i+1) + "个测试元组开始测试***********************");
                List<Double> test = testDatas.get(i);
                System.out.print("测试数据: ");
                for (int j = 0; j < test.size(); j++) {       
                    System.out.print(test.get(j) + " ");
                }
                int c = Math.round(Float.parseFloat((knn.knn(datas, test, 3))));
                /*System.out.println();*/
                System.out.println("类别为: " + c);
                System.out.println("********************第" + (i+1) + "个测试元组结束测试***********************");
                System.out.println("\n\n\n");
            }

        } catch (Exception e) {
            e.printStackTrace();
        }
    }


}

既存数据:
动作次数      亲吻次数      电影类别(1:爱情片,0:武打片)
3                 104                 1
2                 100                 1
1                 81                   1
101             10                   0
99               5                     0
98               2                    0
测试数据:
18              90
55              45
90              10
运行结果:
18              90                      1
55              45                      1
90              10                      0


PythonDemo

#导入numy库
import numpy as np
#导入KNeighborsClassifier
from sklearn.neighbors import KNeighborsClassifier

#创建训练数据
sampleData = np.array([[3, 104], [2, 100], [1, 81], [101, 10], [99, 5], [98, 2]])
#训练数据的分类
category = [1, 1, 1, 0, 0, 0]
#建立KNN模型
neigh = KNeighborsClassifier(n_neighbors=3, algorithm='auto').fit(sampleData, category)
#预测测试数据的类别
print(neigh.predict([[18, 90]]))
print(neigh.predict([[55, 45]]))
print(neigh.predict([[90, 10]]))

函数解析:
KNeighborsClassifier是一个类,它集成了其他的NeighborsBase, KNeighborsMixin,SupervisedIntegerMixin, ClassifierMixin。

init() 初始化函数(构造函数) 它主要有一下几个参数:

  n_neighbors=5
int 型参数 knn算法中指定以最近的几个最近邻样本具有投票权,默认参数为5

  weights=’uniform’
str参数即每个拥有投票权的样本是按什么比重投票,’uniform’表示等比重投票,’distance’表示按距离反比投票,[callable]表示自己定义的一个函数,这个函数接收一个

距离数组,返回一个权值数组。默认参数为‘uniform’

  algrithm=’auto’
str参数 即内部采用什么算法实现。有以下几种选择参数:’ball_tree’:球树、’kd_tree’:kd树、’brute’:暴力搜索、’auto’:自动根据数据的类型和结构选择合适的算法。默认情况下是‘auto’。暴力搜索就不用说了大家都知道。具体前两种树型数据结构哪种好视情况而定。KD树是对依次对K维坐标轴,以中值切分构造的树,每一个节点是一个超矩形,在维数小于20时效率最高–可以参看《统计学习方法》第二章。ball tree 是为了克服KD树高维失效而发明的,其构造过程是以质心C和半径r分割样本空间,每一个节点是一个超球体。一般低维数据用kd_tree速度快,用ball_tree相对较慢。超过20维之后的高维数据用kd_tree效果反而不佳,而ball_tree效果要好,具体构造过程及优劣势的理论大家有兴趣可以去具体学习。

  leaf_size=30
int参数 基于以上介绍的算法,此参数给出了kd_tree或者ball_tree叶节点规模,叶节点的不同规模会影响数的构造和搜索速度,同样会影响储树的内存的大小。具体最优规模是多少视情况而定。

  matric=’minkowski’
str或者距离度量对象 即怎样度量距离。默认是闵氏距离,闵氏距离不是一种具体的距离度量方法,它可以说包括了其他距离度量方式,是其他距离度量的推广,具体各种距离度量只是参数p的取值不同或者是否去极限的不同情况,具体大家可以参考这里,讲的非常详细                           
  p=2
int参数就是以上闵氏距离各种不同的距离参数,默认为2,即欧氏距离。p=1代表曼哈顿距离等等

  metric_params=None
距离度量函数的额外关键字参数,一般不用管,默认为None

  n_jobs=1
int参数 指并行计算的线程数量,默认为1表示一个线程,为-1的话表示为CPU的内核数,也可以指定为其他数量的线程,这里不是很追求速度的话不用管,需要用到的话去看看多线程。

fit()
训练函数,它是最主要的函数。接收参数只有1个,就是训练数据集,每一行是一个样本,每一列是一个属性。它返回对象本身,即只是修改对象内部属性,因此直接调用就可以了,后面用该对象的预测函数取预测自然及用到了这个训练的结果。其实该函数并不是KNeighborsClassifier这个类的方法,而是它的父类SupervisedIntegerMixin继承下来的方法。

predict()
预测函数 接收输入的数组类型测试样本,一般是二维数组,每一行是一个样本,每一列是一个属性返回数组类型的预测结果,如果每个样本只有一个输出,则输出为一个一维数组。如果每个样本的输出是多维的,则输出二维数组,每一行是一个样本,每一列是一维输出。

predict_prob()
基于概率的软判决,也是预测函数,只是并不是给出某一个样本的输出是哪一个值,而是给出该输出是各种可能值的概率各是多少接收参数和上面一样返回参数和上面类似,只是上面该是值的地方全部替换成概率,比如说输出结果又两种选择0或者1,上面的预测函数给出的是长为n的一维数组,代表各样本一次的输出是0还是1.而如果用概率预测函数的话,返回的是n*2的二维数组,每一行代表一个样本,每一行有两个数,分别是该样本输出为0的概率为多少,输出1的概率为多少。而各种可能的顺序是按字典顺序排列,比如先0后1,或者其他情况等等都是按字典顺序排列。

score()
计算准确率的函数,接受参数有3个。 X:接收输入的数组类型测试样本,一般是二维数组,每一行是一个样本,每一列是一个属性。y:X这些预测样本的真实标签,一维数组或者二维数组。sample_weight=None,是一个和X第一位一样长的各样本对准确率影响的权重,一般默认为None.输出为一个float型数,表示准确率。内部计算是按照predict()函数计算的结果记性计算的。其实该函数并不是KNeighborsClassifier这个类的方法,而是它的父类KNeighborsMixin继承下来的方法。

kneighbors()
计算某些测试样本的最近的几个近邻训练样本。接收3个参数。X=None:需要寻找最近邻的目标样本。n_neighbors=None,表示需要寻找目标样本最近的几个最近邻样本,默认为None,需要调用时给出。return_distance=True:是否需要同时返回具体的距离值。返回最近邻的样本在训练样本中的序号。其实该函数并不是KNeighborsClassifier这个类的方法,而是它的父类KNeighborsMixin继承下来的方法。

参考资料:

https://wenku.baidu.com/view/d84cf670a5e9856a561260ce.html
https://wenku.baidu.com/view/d84cf670a5e9856a561260ce.html
https://www.cnblogs.com/xiaotan-code/p/6680438.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值