KNN邻近算法(K nearest neighbor)
知识储备:
训练样本是用来训练学习机的,测试样本是学习机要识别的对象。
KNN就是选择某个测试样本的K个最近邻(训练样本,即已经知道分类的数据)X中多数所属的类别作为X的类别。
- 简介
- 准备数据,对数据进行预处理
- 选用合适的数据结构存储训练数据和测试元组
- 设定参数,如k
- 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
- 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
- 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
- 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
- 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。
- 算法MATLAB实现
function target=KNN(in,out,test,k)
% in: training samples data,n*d matrix
% out: training samples' class label,n*1
% test: testing data
% target: class label given by knn
% k: the number of neighbors
ClassLabel=unique(out);
c=length(ClassLabel);
n=size(in,1);
% target=zeros(size(test,1),1);
dist=zeros(size(in,1),1);
for j=1:size(test,1)
cnt=zeros(c,1);
for i=1:n
dist(i)=norm(in(i,:)-test(j,:));
end
[d,index]=sort(dist);
for i=1:k
ind=find(ClassLabel==out(index(i)));
cnt(ind)=cnt(ind)+1;
end
[m,ind]=max(cnt);
target(j)=ClassLabel(ind);
end
- 算法JAVA实现
package KNN;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
* KNN算法主体类
*
* @author June
* @qq 544348879
* @mail 544348879@qq.com
* @data 2015.12.10
*/
public class KNN {
/**
* 设置优先级队列的比较函数,距离越大,优先级越高
*/
private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 获取K个不同的随机数
*
* @param k
* 随机数的个数
* @param max
* 随机数最大的范围
* @return 生成的随机数数组
*/
public List<Integer> getRandKNum(int k, int max) {
List<Integer> rand = new ArrayList<Integer>(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 计算测试元组与训练元组之前的距离
*
* @param d1
* 测试元组
* @param d2
* 训练元组
* @return 距离值
*/
public double calDistance(List<Double> d1, List<Double> d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 执行KNN算法,获取测试元组的类别
*
* @param datas
* 训练数据集
* @param testData
* 测试元组
* @param k
* 设定的K值
* @return 测试元组的类别
*/
public String knn(List<List<Double>> datas, List<Double> testData, int k) {
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
List<Integer> randNum = getRandKNum(k, datas.size());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List<Double> currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
KNNNode node = new KNNNode(index, calDistance(testData, currData),
c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List<Double> t = datas.get(i);
double distance = calDistance(testData, t);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq
.add(new KNNNode(i, distance, t.get(t.size() - 1)
.toString()));
}
}
return getMostClass(pq);
}
/**
* 获取所得到的k个最近邻元组的多数类
*
* @param pq
* 存储k个最近近邻元组的优先级队列
* @return 多数类的名称
*/
private String getMostClass(PriorityQueue<KNNNode> pq) {
Map<String, Integer> classCount = new HashMap<String, Integer>();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}