完整工程代码下载路径:
https://download.csdn.net/download/luohualiushui1/10949773
k-近邻算法可以说是机器学习里面比较简单的算法。它的概念如下:
如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
我们可以从资料中查到k-近邻算法的python语言实现代码:
def classifyO(inX,dataSet,labels,k):
dataSetSize=dataSet.shape[0]
diffMat=tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argaort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
sortedClassCount=sorted(classCount.iteritems(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
现在我们需要用java来实现该分类器
python的依赖库比较丰富,numpy的矩阵运算也是非常完善,而java则在这个方面比较薄弱。
我们首先要选择java的矩阵处理库,一开始我选用ujmp发现这个库并不够用,然后选择了ejml。
而ejml也还是不够python的numpy完善,有几个方法需要自己实现。例如:
numpy的tile方法,我用java实现如下:
public static DenseMatrix64F tile(DenseMatrix64F input,int length) {
DenseMatrix64F rs = new DenseMatrix64F(length,input.numCols);
for(int i=0;i<length;i++) {
for(int j=0;j<input.numCols;j++) {
rs.set(i, j, input.get(0, j));
}
}
return rs;
}
numpy的argsort方法,这个是获取数组排序之后的元素原索引,我用java的实现如下:
public static int[] argsort(DenseMatrix64F input) {
int[] rs = new int[input.numRows];
for(int i=0;i<input.numRows;i++){
rs[i] = i;
}
for(int i=0;i<input.numRows-1;i++) {
for(int j=i+1;j<input.numRows;j++) {
if(input.get(i,0) > input.get(j, 0)) {
double tmp = input.get(j, 0);
int tmpIndex = rs[j];
input.set(j, 0, input.get(i,0));
input.set(i, 0, tmp);
rs[j] = rs[i];
rs[i] = tmpIndex;
}
}
}
return rs;
}
最后实现关键方法:
public static String classify(DenseMatrix64F input,DenseMatrix64F his,String [] labels,int k) {
DenseMatrix64F diffs = new DenseMatrix64F(his.numRows,his.numCols);
CommonOps.subtract(tile(input,his.numRows), his, diffs);
DenseMatrix64F mult = new DenseMatrix64F(his.numRows,his.numCols);
CommonOps.elementMult(diffs, diffs, mult);
DenseMatrix64F dis = new DenseMatrix64F(his.numRows,1);
CommonOps.sumRows(mult, dis);
DenseMatrix64F sqrt = new DenseMatrix64F(his.numRows,1);
CommonOps.elementPower(dis,0.5, sqrt);
int[] args = argsort(sqrt);
Map<String,Integer> rs = new HashMap<String,Integer>();
for(int i=0;i<k;i++) {
if(rs.containsKey(labels[args[i]])) {
int tmp = rs.get(labels[args[i]])+1;
rs.remove(labels[args[i]]);
rs.put(labels[args[i]],tmp);
}else {
rs.put(labels[args[i]], 1);
}
}
String lab = "";
int labNum = 0;
for (Map.Entry<String, Integer> entry : rs.entrySet()){
if(entry.getValue() > labNum) {
labNum = entry.getValue();
lab = entry.getKey();
}
}
return lab;
}
ok,接下来我们初始化矩阵实现:
DenseMatrix64F datas = new DenseMatrix64F(4,2);
datas.set(0,0,1.0);
datas.set(0,1,1.1);
datas.set(1,0,1.0);
datas.set(1,1,1.0);
datas.set(2,0,0);
datas.set(2,1,0);
datas.set(3,0,0);
datas.set(3,1,0.1);
DenseMatrix64F input = new DenseMatrix64F(1,2);
input.set(0,0,0);
input.set(0,1,0.1);
System.out.println(datas);
System.out.println(datas.numRows);
String labels[] = {"A","A","B","C"};
System.out.println(classify(input,datas,labels,1));
结果如下: