KNN邻近算法

KNN邻近算法是一种经典的机器学习和数据挖掘方法。它通过寻找测试样本的K个最近邻训练样本,根据多数类别来预测测试样本的类别。算法流程包括数据预处理、设定K值、维护优先级队列以及计算多数类。可以使用MATLAB或JAVA实现KNN算法。
摘要由CSDN通过智能技术生成

KNN邻近算法(K nearest neighbor)


知识储备:
训练样本是用来训练学习机的,测试样本是学习机要识别的对象。

KNN就是选择某个测试样本的K个最近邻(训练样本,即已经知道分类的数据)X中多数所属的类别作为X的类别。


  • 简介
    1. 准备数据,对数据进行预处理
    2. 选用合适的数据结构存储训练数据和测试元组
    3. 设定参数,如k
    4. 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
    5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
    6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
    7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
    8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。

  • 算法MATLAB实现
function target=KNN(in,out,test,k)
% in:       training samples data,n*d matrix
% out:      training samples' class label,n*1
% test:     testing data
% target:   class label given by knn
% k:        the number of neighbors
ClassLabel=unique(out);
c=length(ClassLabel);
n=size(in,1);
% target=zeros(size(test,1),1);
dist=zeros(size(in,1),1);
for j=1:size(test,1)
    cnt=zeros(c,1);
    for i=1:n
        dist(i)=norm(in(i,:)-test(j,:));
    end
    [d,index]=sort(dist);
    for i=1:k
        ind=find(ClassLabel==out(index(i)));
        cnt(ind)=cnt(ind)+1;
    end
    [m,ind]=max(cnt);
    target(j)=ClassLabel(ind);
end
  • 算法JAVA实现

package KNN;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * KNN算法主体类
 * 
 * @author June
 * @qq 544348879
 * @mail 544348879@qq.com
 * @data 2015.12.10
 */
public class KNN {
    /**
     * 设置优先级队列的比较函数,距离越大,优先级越高
     */
    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
        public int compare(KNNNode o1, KNNNode o2) {
            if (o1.getDistance() >= o2.getDistance()) {
                return 1;
            } else {
                return 0;
            }
        }
    };

    /**
     * 获取K个不同的随机数
     * 
     * @param k
     *            随机数的个数
     * @param max
     *            随机数最大的范围
     * @return 生成的随机数数组
     */
    public List<Integer> getRandKNum(int k, int max) {
        List<Integer> rand = new ArrayList<Integer>(k);
        for (int i = 0; i < k; i++) {
            int temp = (int) (Math.random() * max);
            if (!rand.contains(temp)) {
                rand.add(temp);
            } else {
                i--;
            }
        }
        return rand;
    }

    /**
     * 计算测试元组与训练元组之前的距离
     * 
     * @param d1
     *            测试元组
     * @param d2
     *            训练元组
     * @return 距离值
     */
    public double calDistance(List<Double> d1, List<Double> d2) {
        double distance = 0.00;
        for (int i = 0; i < d1.size(); i++) {
            distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
        }
        return distance;
    }

    /**
     * 执行KNN算法,获取测试元组的类别
     * 
     * @param datas
     *            训练数据集
     * @param testData
     *            测试元组
     * @param k
     *            设定的K值
     * @return 测试元组的类别
     */
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
        List<Integer> randNum = getRandKNum(k, datas.size());
        for (int i = 0; i < k; i++) {
            int index = randNum.get(i);
            List<Double> currData = datas.get(index);
            String c = currData.get(currData.size() - 1).toString();
            KNNNode node = new KNNNode(index, calDistance(testData, currData),
                    c);
            pq.add(node);
        }
        for (int i = 0; i < datas.size(); i++) {
            List<Double> t = datas.get(i);
            double distance = calDistance(testData, t);
            KNNNode top = pq.peek();
            if (top.getDistance() > distance) {
                pq.remove();
                pq
                        .add(new KNNNode(i, distance, t.get(t.size() - 1)
                                .toString()));
            }
        }

        return getMostClass(pq);
    }

    /**
     * 获取所得到的k个最近邻元组的多数类
     * 
     * @param pq
     *            存储k个最近近邻元组的优先级队列
     * @return 多数类的名称
     */
    private String getMostClass(PriorityQueue<KNNNode> pq) {
        Map<String, Integer> classCount = new HashMap<String, Integer>();
        for (int i = 0; i < pq.size(); i++) {
            KNNNode node = pq.remove();
            String c = node.getC();
            if (classCount.containsKey(c)) {
                classCount.put(c, classCount.get(c) + 1);
            } else {
                classCount.put(c, 1);
            }
        }
        int maxIndex = -1;
        int maxCount = 0;
        Object[] classes = classCount.keySet().toArray();
        for (int i = 0; i < classes.length; i++) {
            if (classCount.get(classes[i]) > maxCount) {
                maxIndex = i;
                maxCount = classCount.get(classes[i]);
            }
        }
        return classes[maxIndex].toString();
    }
}
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值