用java实现k-近邻算法分类器

20 篇文章 0 订阅
18 篇文章 0 订阅

完整工程代码下载路径:

https://download.csdn.net/download/luohualiushui1/10949773

k-近邻算法可以说是机器学习里面比较简单的算法。它的概念如下:

如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

我们可以从资料中查到k-近邻算法的python语言实现代码:

def classifyO(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argaort()
    classCount={}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    sortedClassCount=sorted(classCount.iteritems(),
    key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

现在我们需要用java来实现该分类器

python的依赖库比较丰富,numpy的矩阵运算也是非常完善,而java则在这个方面比较薄弱。

我们首先要选择java的矩阵处理库,一开始我选用ujmp发现这个库并不够用,然后选择了ejml。

而ejml也还是不够python的numpy完善,有几个方法需要自己实现。例如:

numpy的tile方法,我用java实现如下:

public static DenseMatrix64F tile(DenseMatrix64F input,int length) {
		
		DenseMatrix64F rs = new DenseMatrix64F(length,input.numCols);
		
		for(int i=0;i<length;i++) {
			for(int j=0;j<input.numCols;j++) {
				rs.set(i, j, input.get(0, j));
			}
		}
		
		return rs;
	}

numpy的argsort方法,这个是获取数组排序之后的元素原索引,我用java的实现如下:

	public static int[] argsort(DenseMatrix64F input) {
		
		int[] rs =  new int[input.numRows];
		
		for(int i=0;i<input.numRows;i++){
			rs[i] = i;
		}
		
		for(int i=0;i<input.numRows-1;i++) {
			for(int j=i+1;j<input.numRows;j++) {
				if(input.get(i,0) > input.get(j, 0)) {
					
					double tmp = input.get(j, 0);
					int tmpIndex = rs[j];
					
					input.set(j, 0, input.get(i,0));
					input.set(i, 0, tmp);
					
					rs[j] = rs[i];
					rs[i] = tmpIndex;

				}
			}
		}
		
		return rs;
	}

最后实现关键方法:

	public static String classify(DenseMatrix64F input,DenseMatrix64F his,String [] labels,int k) {
		
		DenseMatrix64F diffs = new DenseMatrix64F(his.numRows,his.numCols); 

        CommonOps.subtract(tile(input,his.numRows), his, diffs);
        
        DenseMatrix64F mult = new DenseMatrix64F(his.numRows,his.numCols); 
        
        CommonOps.elementMult(diffs, diffs, mult);
        
        DenseMatrix64F dis = new DenseMatrix64F(his.numRows,1); 
        
        CommonOps.sumRows(mult, dis);
        
        DenseMatrix64F sqrt = new DenseMatrix64F(his.numRows,1); 
        
        CommonOps.elementPower(dis,0.5, sqrt);
        
        int[] args = argsort(sqrt); 
        
        Map<String,Integer> rs = new HashMap<String,Integer>();
		
		for(int i=0;i<k;i++) {
			
			if(rs.containsKey(labels[args[i]])) {
				int tmp = rs.get(labels[args[i]])+1;
				rs.remove(labels[args[i]]);
				rs.put(labels[args[i]],tmp);
			}else {
				rs.put(labels[args[i]], 1);
			}
		}
		
		String lab = "";
		
		int labNum = 0;
		
		for (Map.Entry<String, Integer> entry : rs.entrySet()){ 
			
			if(entry.getValue() > labNum) {
				labNum = entry.getValue();
				lab = entry.getKey();
			}
		}
		
		return lab;
		
	}

ok,接下来我们初始化矩阵实现:

		DenseMatrix64F datas = new DenseMatrix64F(4,2);
			datas.set(0,0,1.0);
			datas.set(0,1,1.1);
			datas.set(1,0,1.0);
			datas.set(1,1,1.0);
			datas.set(2,0,0);
			datas.set(2,1,0);
			datas.set(3,0,0);
			datas.set(3,1,0.1);
			
		DenseMatrix64F input = new DenseMatrix64F(1,2);
			input.set(0,0,0);
			input.set(0,1,0.1);
		
	    System.out.println(datas);	
	    System.out.println(datas.numRows);	
	    
	    String labels[] = {"A","A","B","C"};
	    
	    
	    System.out.println(classify(input,datas,labels,1));

结果如下:

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

路边草随风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值