Weka KNN 算法的改进与实现

本文基于 Weka 对 KNN 算法进行改进：主要使用高斯函数对近邻投票加权，并选取最优的 K 值。你也可以参考网上的相关文档，将下文的 KNN_lsh.java 复制到相应目录并完成配置，从而在 Weka GUI 中测试改进效果。


文件目录:

（此处原为文件目录结构截图，图片已缺失：两个源文件均位于 /weka_test/src/cug/lsh/ 目录下。）

/weka_test/src/cug/lsh/KNN_lsh.java 如下:

package cug.lsh;

import weka.classifiers.*;
import weka.core.*;
import java.util.*;

@SuppressWarnings("serial")
public class KNN_lsh extends Classifier {

	private Instances m_Train;
	private int m_kNN;
	
	public void setM_kNN(int m_kNN) {
		this.m_kNN = m_kNN;
	}
	
	public void buildClassifier(Instances data) throws Exception {
		m_Train = new Instances(data);	
		
	}

	public double[] distributionForInstance(Instance instance) throws Exception {

		Instances instances= findNeighbors(instance, m_kNN);
		return computeDistribution(instances, instance);
	}
	
	private Instances findNeighbors(Instance instance, int kNN) {
		double distance;	
		List<HasDisInstances> neighborlist = new LinkedList<>();
		
		for (int i = 0; i < m_Train.numInstances(); i++) {
			Instance trainInstance = m_Train.instance(i);
			distance = distance(instance, trainInstance);
			HasDisInstances hasDisInstances=new HasDisInstances(distance,trainInstance);
			
			if(i==0 || (i<kNN-1 && neighborlist.get(neighborlist.size()-1).distance<distance))
				neighborlist.add(hasDisInstances);
			else{
				for (int j = 0; j < kNN && j<neighborlist.size(); j++) {
					if(distance<neighborlist.get(j).distance){
						neighborlist.add(j, hasDisInstances);
						break;
					}
				}
			}
		}
		
		int min=Math.min(kNN, neighborlist.size());
		Instances instances=new Instances(m_Train,min);
		for(int i=0;i<min;i++){
			instances.add(neighborlist.get(i).instance);
		}
		return instances;
	}

	private double distance(Instance first, Instance second) {

		double distance = 0;
		for (int i = 0; i < m_Train.numAttributes(); i++) {
			if (i == m_Train.classIndex())
				continue;
			if((int)first.value(i)!=(int)second.value(i)){
		        distance+=1;
		      }
//			//此处修改距离计算公式
//			distance+=(second.value(i)-first.value(i))*(second.value(i)-first.value(i));//欧基米德尔公式
//			distance+=second.value(i)*Math.log(second.value(i)/first.value(i));最大熵
//			distance+=Math.pow((second.value(i)-first.value(i)), 2)/first.value(i);//卡方距离
		}
//		distance=Math.sqrt(distance);
		return distance;
	}

	private double[] computeDistribution(Instances data, Instance instance) throws Exception {
		
	    double[] prob=new double[data.numClasses()];

	    for (int i=0;i<data.numInstances();i++){
	      int classVal=(int)data.instance(i).classValue();
	      double x=distance(instance, data.instance(i));
	      prob[classVal] +=1+Math.exp(-x*x/0.18);//c=0.3
	    }
		Utils.normalize(prob);
		return prob;
	}

	private class HasDisInstances{
		double distance;
		Instance instance;
		public HasDisInstances(double distance, Instance instance) {
			this.distance = distance;
			this.instance = instance;
		}
	}
}

/weka_test/src/cug/lsh/KNN_lsh_use.java(主函数) 如下:

package cug.lsh;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

/**
 * Demo driver for {@link KNN_lsh}.
 *
 * <p>It loads an ARFF dataset and uses the last attribute as the class. The first 20%
 * of the instances become the test set and the rest the training set. It then tries
 * k = 3, 4, ... while k < sqrt(#train) + 3 and k <= 20. Finally it prints
 * "accuracy,k" for the k with the most correctly classified test instances.
 *
 * <p>Usage: {@code java cug.lsh.KNN_lsh_use [path/to/data.arff]}.
 */
public class KNN_lsh_use {

	/** Dataset used when no path is given on the command line. */
	private static final String DEFAULT_DATA = "E:/DataLearing/data/credit-g.arff";

	public static void main(String[] args) throws Exception {
		Instances data = DataSource.read(args.length > 0 ? args[0] : DEFAULT_DATA);
		data.setClassIndex(data.numAttributes() - 1);

		// Split by position: first 20% -> test, remaining 80% -> train.
		// The range-copy constructor also copies the header, including the class index.
		int testSize = (int) (data.numInstances() * 0.2);
		Instances test = new Instances(data, 0, testSize);
		Instances train = new Instances(data, testSize, data.numInstances() - testSize);
		if (test.numInstances() == 0) {
			System.err.println("Dataset too small: the 20% test split is empty.");
			return;
		}

		// KNN_lsh reads k only at prediction time, so a single build serves every k.
		KNN_lsh classifier = new KNN_lsh();
		classifier.buildClassifier(train);

		int bestK = 0;
		int bestCorrect = 0; // correctly classified test instances for bestK
		for (int k = 3; k < Math.sqrt(train.numInstances()) + 3 && k <= 20; k++) {
			classifier.setM_kNN(k);
			int correct = 0;
			for (int i = 0; i < test.numInstances(); i++) {
				if (classifier.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
					correct++;
				}
			}
			if (correct > bestCorrect) {
				bestK = k;
				bestCorrect = correct;
			}
		}

		// NOTE: k is tuned on the same split whose accuracy is reported, so this figure is optimistic.
		System.out.println(1.0 * bestCorrect / test.numInstances() + "," + bestK);
	}
}

  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值