1、数据实体类
训练数据TrainRecord和测试数据TestRecord均需要继承Record数据实体类。其中TrainRecord多了一个保存距离的distance成员变量,TestRecord多了一个保存预测类别的predictLabel成员变量。
基类:
public class Record {
public double[] attributes;
public int classLabel;
Record(double[] attributes, int classLabel){
this.attributes = new double[attributes.length];
for (int i = 0; i < attributes.length; i++) {
this.attributes[i] = attributes[i];
}
this.classLabel = classLabel;
}
}
2、计算距离类
所有计算距离类都要实现Metric接口,实现getDistance获取距离函数。下面给出最普通的欧氏距离代码。
public class EuclideanDistance implements Metric {
@Override
public double getDistance(Record s, Record e) {
assert s.attributes.length == e.attributes.length : "s and e are different types of records!";
int numOfAttributes = s.attributes.length;
double sum2 = 0;
for(int i = 0; i < numOfAttributes; i ++){
sum2 += Math.pow(s.attributes[i] - e.attributes[i], 2);
}
return Math.sqrt(sum2);
}
}
3、找到K个最近邻
KNN算法核心的就是计算距离并返回最核心的K个近邻,首先定义一个保存K近邻的数组。
TrainRecord[] neighbors = new TrainRecord[K];
其次,计算测试实例与所有训练实例使用距离公式计算距离,使用neighbors数组记录最近的K个近邻。
for(index = K; index < NumOfTrainingSet; index ++){
trainingSet[index].distance = metric.getDistance(trainingSet[index], testRecord);
//找出距离最远的实例将其替换掉
int maxIndex = 0;
for(int i = 1; i < K; i ++){
if(neighbors[i].distance > neighbors[maxIndex].distance)
maxIndex = i;
}
if(neighbors[maxIndex].distance > trainingSet[index].distance)
neighbors[maxIndex] = trainingSet[index];
}
4、分类
在分类时,针对k个近邻,距离越远的权值越小。通过投票计算所有类别对应的值,返回最大值对应的类别。
for(int index = 0;index < num; index ++) {
TrainRecord temp = neighbors[index];
int key = temp.classLabel;
if (!map.containsKey(key))
map.put(key, 1 / temp.distance);
else {
double value = map.get(key);
value += 1 / temp.distance;
map.put(key, value);
}
}
double maxSimilarity = 0;
int returnLabel = -1;
Set<Integer> labelSet = map.keySet();
Iterator<Integer> it = labelSet.iterator();
while(it.hasNext()){
int label = it.next();
double value = map.get(label);
if(value > maxSimilarity){
maxSimilarity = value;
returnLabel = label;
}
}
结语:计算公式可以进行改进,不单单使用欧氏距离。同时加权的权重可以进行调整,使其更精确。