A simple example of the KNN (k-nearest neighbours) algorithm, with comments
Compute.java
public class Compute {

    /**
     * Classifies every test point and returns the error rate.
     * @param train      training set
     * @param flag       class label of each training row
     * @param test       test set
     * @param flagOrigin true class label of each test row
     * @param k          number of nearest neighbours used for voting
     * @param tag        the set of possible class labels
     * @return fraction of test points that were misclassified
     */
    public float integrate(float[][] train, String[] flag, float[][] test, String[] flagOrigin, int k, String[] tag)
    {
        float[] distance;
        int[] index;
        String ans = null;
        int count = 0;
        System.out.println("Testing...");
        for (int i = 0; i < test.length; i++)
        {
            // distances from every training point to test[i]
            distance = getDistance(train, test[i]);
            // sort the distances, keeping track of the original indices
            index = sort(distance);
            // majority vote among the k nearest neighbours
            ans = getAnswer(index, k, flag, tag);
            System.out.println("The class predicted is " + ans + "\t" + "and the original class is " + flagOrigin[i]);
            if (!ans.equals(flagOrigin[i]))
            {
                count++;
            }
        }
        return (float) count / test.length;
    }
    /**
     * Computes the Euclidean distance from every training point to the test point.
     * @param train training set
     * @param test  a single test point
     * @return distance from each training row to the test point
     */
    public float[] getDistance(float[][] train, float[] test)
    {
        int row = train.length;    // number of training points
        int col = train[0].length; // number of features per point
        float[] distance = new float[row];
        for (int i = 0; i < row; i++)
        {
            double temp = 0;
            for (int j = 0; j < col; j++)
            {
                temp += Math.pow(test[j] - train[i][j], 2);
            }
            distance[i] = (float) Math.sqrt(temp);
        }
        return distance;
    }
    /**
     * Sorts the distance array in ascending order (a simple selection-style sort)
     * and returns the original indices in the new order.
     * @param dis distance array (sorted in place)
     * @return indices of the training points, ordered by distance
     */
    public int[] sort(float[] dis)
    {
        int length = dis.length;
        int[] index = new int[length]; // original index of each entry after sorting
        for (int i = 0; i < length; i++)
        {
            index[i] = i;
        }
        for (int i = 0; i < length; i++)
        {
            for (int j = i; j < length; j++)
            {
                if (dis[i] > dis[j])
                {
                    float temp = dis[i];
                    dis[i] = dis[j];
                    dis[j] = temp;
                    int t = index[i];
                    index[i] = index[j];
                    index[j] = t;
                }
            }
        }
        return index;
    }
    /**
     * Counts the class labels of the k nearest training points and returns the
     * most frequent one as the predicted class.
     * @param index indices of the training points, ordered by distance
     * @param k     number of nearest neighbours used for voting
     * @param flag  class label of each training point
     * @param tag   the set of possible class labels
     * @return the predicted class label
     */
    public String getAnswer(int[] index, int k, String[] flag, String[] tag)
    {
        // votes per class
        int[] count = new int[tag.length];
        // look at the k nearest training points
        for (int i = 0; i < k; i++)
        {
            for (int j = 0; j < tag.length; j++)
            {
                if (flag[index[i]].equals(tag[j]))
                {
                    count[j]++;
                }
            }
        }
        // index of the class with the most votes
        int ind = max(count);
        // return that class
        return tag[ind];
    }
    /**
     * Finds the index of the largest value in the array.
     * @param count vote counts per class
     * @return index of the maximum value
     */
    public int max(int[] count)
    {
        int max = count[0];
        int index = 0;
        for (int i = 1; i < count.length; i++)
        {
            if (max < count[i])
            {
                max = count[i];
                index = i;
            }
        }
        return index;
    }
}
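The helpers in Compute can also be called one at a time. Below is a minimal sketch of classifying a single query point that way; the class name KnnQuickCheck, the reduced four-row training set, the query point {70, 19} and k = 3 are illustrative choices, while the methods themselves come from the Compute class above.

public class KnnQuickCheck { // hypothetical helper class, not part of the original example
    public static void main(String[] args) {
        Compute com = new Compute();
        // illustrative data: four training points and their class labels
        float[][] train = { {60, 18.4f}, {85.5f, 16.8f}, {43.2f, 20.4f}, {49.2f, 20.4f} };
        String[] flag = { "1", "1", "2", "2" };
        String[] tag = { "1", "2" };  // the possible classes
        float[] query = { 70, 19f };  // made-up query point

        float[] distance = com.getDistance(train, query);      // distance from each training row to the query
        int[] index = com.sort(distance);                       // training indices ordered by distance
        String predicted = com.getAnswer(index, 3, flag, tag);  // majority class among the 3 nearest
        System.out.println("predicted class: " + predicted);
    }
}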
Knn.java
public class Knn {

    public static void main(String[] args) {
        Compute com = new Compute();
        // training set, 18 points
        float[][] train = {
                {60, 18.4f},
                {85.5f, 16.8f},
                {64.5f, 21.6f},
                {61.5f, 20.8f},
                {87, 23.6f},
                {82.8f, 22.4f},
                {69, 20},
                {93, 20.8f},
                {51, 22},
                {75, 19.6f},
                {64.8f, 17.2f},
                {43.2f, 20.4f},
                {84, 17.6f},
                {49.2f, 20.4f},
                {47.4f, 16.4f},
                {43, 18.8f},
                {51, 14},
                {63, 14.8f},
        };
        // class labels of the training set, 18 entries
        String[] flag = {"1", "1", "1", "1", "1", "1", "1", "1", "1",
                         "2", "2", "2", "2", "2", "2", "2", "2", "2"};
        // test set, 6 points
        float[][] test = {
                {32, 19.2f},
                {108, 17.6f},
                {81, 20},
                {52.8f, 20.8f},
                {59.4f, 16},
                {66, 18.4f}};
        // true class labels of the test set
        String[] flagOrigin = {"1", "1", "1", "2", "2", "2"};
        // the possible classes
        String[] tag = {"1", "2"};
        int k = 13;
        float errorRate = com.integrate(train, flag, test, flagOrigin, k, tag);
        System.out.println("correct rate is " + (1 - errorRate));
    }
}
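The choice of k strongly influences a KNN classifier. A minimal sketch for comparing a few values, assuming it is placed at the end of main() above so that com and the data arrays are still in scope:

        // illustrative experiment: re-run the classifier with several odd values of k
        for (int kk = 1; kk <= 17; kk += 2) {
            float err = com.integrate(train, flag, test, flagOrigin, kk, tag);
            System.out.println("k = " + kk + ", correct rate is " + (1 - err));
        }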