用JAVA实现KNN算法

目录

        一、大体思路:

        二、思路操作:

                三个类:数据类Data、算法类KNN、测试类KNNTest

        三、代码实现:

        四、代码运行:


一、大体思路:

  1. 输入所有已知点
  2. 输入未知点
  3. 计算所有已知点到未知点的欧氏距离
  4. 根据距离对所有已知点排序
  5. 选出距离未知点最近的k个点
  6. 计算k个点所在分类出现的频率
  7. 选择频率最大的类别即为未知点的类别

二、思路操作:

        三个类:数据类Data、算法类KNN、测试类KNNTest

Data:

/**
 * Placeholder for the dataset used by KNNTest. The arrays are intentionally
 * empty here; the real values are filled in under section 三 (代码实现) below.
 * Feature rows in X_* align index-for-index with the labels in y_*.
 */
public class Data {
    // Training feature rows, one {x1, x2, ...} array per sample (to be filled in).
    public static double[][]X_train= {{},{},{},{}};
    // Held-out test feature rows (to be filled in).
    public static double[][]X_test= {{},{},{},{}};
    // Training labels, aligned with X_train rows (to be filled in).
    public static int[] y_train= {};
    // Expected test labels, aligned with X_test rows (to be filled in).
    public static int[] y_test= {};
}

KNNTest:

public class KNNTest {

    public static void main(String[] args) {
        // TODO Auto-generated method stub
        KNN model =new KNN();
        model.fit(Data.X_train, Data.y_train);
        int[] yhat=model.predict(Data.X_test);
        for(int i=0;i<yhat.length;i++)
        {
            System.out.print(yhat[i]-Data.y_test[i]+",");
        }
        System.out.print(model.score(Data.X_test, Data.y_test));
    }

}

KNN:

package knn;

import java.util.Map;

import org.omg.CORBA.PRIVATE_MEMBER;

/**
 * Skeleton of the KNN classifier: public API plus private helper stubs.
 * Note: in the original draft the three private helpers were declared inside
 * predict()'s body — Java forbids method declarations inside a method, so
 * that version could not compile. They are hoisted to class level here.
 */
public class KNN {
    // Training data retained by fit() and consulted lazily at prediction time.
    private double[][] X_train;
    private int[] y_train;
    // Number of nearest neighbours that vote on each prediction.
    private int K;

    public KNN() {}

    /** Stores the training samples; KNN is a lazy learner, so no work happens here. */
    public void fit(double[][] X, int[] y) {}

    /** Predicts a label for every row of X. (Stub — implemented in section 三.) */
    public int[] predict(double[][] X) {
        return null;
    }

    /** Returns indices of the K training samples nearest to x. (Stub.) */
    private int[] getTopK(double[] x) {
        return null;
    }

    /** Euclidean distance between two equal-length vectors. (Stub.) */
    private double getL2Distance(double[] a, double[] b) {
        return 0;
    }

    /** Counts label occurrences among the given neighbour indices. (Stub.) */
    private Map<Integer, Integer> getTickCount(int[] indices) {
        return null;
    }

    /** Accuracy on (X_train, y). (Stub — returns 0 until implemented.) */
    public double score(double[][] X_train, int[] y) {
        return 0;
    }
}

 三、代码实现:

Data:

/**
 * Hard-coded two-dimensional, two-class toy dataset used by KNNTest.
 * Feature rows in X_train/X_test align index-for-index with the labels in
 * y_train/y_test. Labels are 0 or 1.
 * NOTE(review): the values look like an exported scikit-learn sample
 * (e.g. make_moons with noise) — origin not verifiable from this file.
 */
public class Data {
	// Training features: one {x1, x2} row per sample.
	public static double[][] X_train = 
			{{ 1.60440089, -0.35116358},
		       {-0.56357758,  1.15774717},
		       {-0.69863061,  0.77620293},
		       {-0.82364866,  0.86822754},
		       { 0.45857236,  0.5017019 },
		       {-1.00392102,  1.15207238},
		       { 0.95709601,  0.30502143},
		       {-0.21911253,  0.49662864},
		       { 0.35766276,  1.0872811 },
		       {-0.25253387,  0.18776697},
		       {-1.15442774,  0.40574243},
		       { 0.11834786,  0.99156023},
		       { 0.55688998,  0.26120887},
		       {-0.62523133,  0.03868576},
		       {-0.01141219,  0.30437635},
		       { 0.30143461,  0.76398572},
		       {-1.02663961,  0.41265823},
		       {-0.97644617,  0.03612864},
		       { 1.5215205 , -0.1258923 },
		       { 2.07463826,  0.51253705},
		       { 0.18911691,  0.64887424},
		       { 1.61652485, -0.29469483},
		       { 0.4630177 ,  0.86392418},
		       { 0.57753017,  0.64564015},
		       { 0.13229471, -0.26032619},
		       { 1.95416454, -0.12850579},
		       { 0.33332321,  0.14006592},
		       { 0.17295579,  0.60135526},
		       { 0.16919147, -0.30895665},
		       {-0.20784982,  1.09043495},
		       { 1.22821377, -0.50119159},
		       {-0.00353234,  0.21487064},
		       { 2.04251223, -0.46074593},
		       { 0.68841773,  1.15068264},
		       { 0.3203754 ,  0.87271389},
		       {-0.70416708,  0.30649274},
		       { 1.79531658, -0.32235591},
		       { 2.02773677,  0.25021451},
		       {-0.66342048,  0.85301441},
		       { 0.81816111,  0.78952806},
		       { 0.8496547 , -0.30507345},
		       { 0.71060915,  0.80990546},
		       {-0.41467068,  0.92254691},
		       {-0.14739828,  0.22556193},
		       { 0.45166927,  1.00185497},
		       { 1.43861749, -0.15796611},
		       { 0.6233932 , -0.69422694},
		       {-1.32109805,  0.51437657},
		       { 0.97332087, -0.70530678},
		       {-0.24648981,  1.09136047},
		       { 0.04167562, -0.07462092},
		       { 0.56116634, -0.0330033 },
		       { 1.30440877, -0.52950917},
		       { 0.1240315 ,  0.67672995},
		       { 0.29809162,  0.82211432},
		       { 1.12344625, -0.09556327},
		       {-0.56365899,  0.8918972 },
		       { 0.27695668,  0.01210816},
		       { 0.30966003,  1.16677531},
		       {-0.57716798,  0.2942259 },
		       {-0.94262451,  0.57258351},
		       { 0.87081342, -0.4366643 },
		       { 1.12820353,  0.4664342 },
		       { 0.80862106,  0.28415372},
		       {-0.15878875,  0.25584465},
		       { 0.62273618, -0.52804644},
		       {-0.27631969,  1.34161045},
		       { 0.48281721, -0.43196099},
		       { 0.93694537,  0.36597075},
		       {-0.22047436,  1.28343927},
		       { 0.74346118,  0.46465633},
		       {-0.3802171 ,  0.88414623},
		       { 0.99716574,  0.35017425},
		       {-1.35462041,  0.28524762},
		       { 0.84248307,  0.55728315},
		       { 0.53148147, -0.27424077},
		       { 0.64426474, -0.36209808},
		       { 1.05840957,  0.51858443},
		       { 1.16443301,  0.01495781},
		       { 1.60172165, -0.37604995},
		       { 0.78362442, -0.25844144},
		       { 1.99150162,  0.02534858},
		       { 1.24491552, -0.5137574 },
		       { 0.3395913 , -0.02223857},
		       {-0.7204608 ,  1.01733354},
		       {-0.46610015,  0.98764879},
		       { 1.48069489, -0.3572808 },
		       { 1.0628775 ,  0.17231496},
		       { 0.55631903, -0.70781481},
		       {-0.38454276,  0.50916381},
		       { 0.88357043, -0.35868845},
		       { 2.04970274,  0.66368306},
		       { 1.89660871,  0.25413209},
		       { 1.40426435, -0.93206382},
		       {-0.93313522,  0.73385959},
		       { 1.26735927, -0.11813675},
		       { 0.6975216 , -0.11832611},
		       { 1.2278091 , -0.64785108},
		       { 0.63918299,  0.96606739},
		       {-0.64018301,  0.75214137},
		       { 1.68734249, -0.6872367 },
		       { 1.51241178,  0.11081331},
		       { 0.1023533 ,  1.09326207},
		       {-0.72894228,  0.44179419},
		       { 1.06739115, -0.38783511},
		       { 0.99633397,  0.1731019 },
		       {-0.25531442,  0.83953933},
		       { 1.4298847 , -0.21080222},
		       { 0.13387667,  0.6944329 },
		       { 2.13330659,  0.11200406},
		       { 2.25536302,  0.02862685},
		       { 1.98373996, -0.11222315},
		       { 0.09048712,  0.0890939 },
		       { 0.15931878, -0.02835184},
		       {-0.13270154,  1.26653562},
		       {-1.088752  , -0.39694315},
		       { 0.99826403, -0.80979075},
		       { 1.0717818 , -0.40141988},
		       {-0.06114159, -0.02921072},
		       {-0.50840939,  0.55259494},
		       {-0.93516752,  0.43520803},
		       { 0.94193283,  0.63204507},
		       {-0.48367053,  0.43679813},
		       { 1.65936346, -0.70351567},
		       { 0.91722632, -0.29657499},
		       { 0.80950246,  0.3505231 },
		       { 0.06721632, -0.1649077 },
		       { 1.06519327, -0.38867949},
		       { 1.65755662, -0.63203157},
		       { 0.98031957, -0.56811367},
		       { 0.37268327,  1.01743002},
		       {-1.07873435,  0.36644163},
		       { 1.02592325,  0.42143427},
		       { 0.88446589, -0.47595401},
		       { 1.6049806 ,  0.13835516},
		       {-0.0557186 ,  0.57286794},
		       { 0.82262609, -0.02317445},
		       { 1.773464  , -0.34102513},
		       { 0.26375409,  0.91508367},
		       {-0.8970433 ,  0.87690996},
		       { 0.45046033,  1.09585861},
		       { 0.23337057,  0.10750568},
		       { 1.32734432, -0.48056888},
		       { 0.73451679,  0.5346233 },
		       { 0.34005355,  0.32486358},
		       { 1.99317676,  0.48903983},
		       { 2.08208438,  0.00715606},
		       {-0.64683386,  0.46971252},
		       { 0.20331   ,  0.19519454},
		       { 1.33653621, -0.18005761},
		       {-1.0873102 ,  0.78128608},
		       { 1.9369961 ,  0.63112161},
		       {-0.33698825,  0.89060661},
		       { 0.87629641,  0.28951999},
		       { 0.71787709, -0.09708361},
		       { 0.29449971, -0.26078938},
		       {-1.13305483,  0.11109962},
		       { 0.01262924, -0.59715374},
		       {-0.12213442,  0.15292037},
		       { 0.87161309,  0.01715969}};
	// Training labels (0 or 1), aligned with X_train rows.
	public static int[] y_train= 
		{1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
	    0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1,
	    0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0,
        1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1,
        1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0,
        1, 1, 0, 1, 1, 0};
	
	// Held-out test features: one {x1, x2} row per sample.
	public static double[][] X_test = 
			   {{ 0.1059936 ,  0.90514125},
		       { 0.63711933, -0.55537183},
		       { 0.75877114,  0.76057045},
		       { 0.17673342,  1.3178874 },
		       { 1.92673019,  0.40817963},
		       { 0.25919429,  1.04104213},
		       { 1.77377397, -0.10513907},
		       { 1.768052  , -0.25443213},
		       { 0.03351023,  0.63113817},
		       { 1.51658699, -0.45069719},
		       {-1.2855903 ,  0.10677262},
		       { 0.2482245 ,  0.7860477 },
		       { 0.09252135, -0.31618454},
		       { 1.96682571,  0.2646737 },
		       {-0.59909677,  0.76903039},
		       { 0.6515764 , -0.40677494},
		       { 0.6779624 ,  0.78024482},
		       { 1.79414416,  0.28323389},
		       { 1.01934807, -0.17993629},
		       {-0.21582418,  1.03521642},
		       { 0.4828491 , -0.21452374},
		       { 0.66388701,  0.94659669},
		       { 1.53296943, -0.36277826},
		       { 2.0436285 ,  0.24563453},
		       {-0.39131894,  0.40925201},
		       { 0.75307594,  0.8526869 },
		       {-0.75191922,  0.63798317},
		       {-0.03857867,  0.0838378 },
		       { 0.88155818,  0.23925957},
		       { 1.82749513, -0.03085446},
		       {-0.70574528,  0.54883003},
		       { 1.45384401,  0.12718529},
		       {-0.07795147,  0.27995261},
		       {-0.70851168,  0.49617855},
		       { 1.43380709,  0.69183071},
		       { 0.66137686, -0.31314104},
		       { 0.25403599, -0.00644002},
		       {-0.52953439,  0.69307316},
		       {-0.99805184,  0.62420544},
		       {-1.07840959,  0.56402523}};
	// Expected test labels (0 or 1), aligned with X_test rows.
	public static int[] y_test = {0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0,
	                              1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0};
}
KNNTest:

public class KNNTest {

	/**
	 * Fits a KNN classifier on the training split, then prints the prediction
	 * error per test sample (yhat - y; 0 means correct), a newline, and finally
	 * the accuracy on the test split.
	 */
	public static void main(String[] args) {

		KNN classifier = new KNN();
		classifier.fit(Data.X_train, Data.y_train);

		int[] predictions = classifier.predict(Data.X_test);
		StringBuilder diffs = new StringBuilder();
		for (int idx = 0; idx < predictions.length; idx++)
		{
			diffs.append(predictions[idx] - Data.y_test[idx]).append(",");
		}
		System.out.print(diffs);
		System.out.println();
		System.out.print(classifier.score(Data.X_test, Data.y_test));
	}

}
KNN:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * K-nearest-neighbours classifier for integer class labels.
 * Usage: fit(X, y) stores the training data; predict/score compute
 * majority-vote classifications using Euclidean distance and K = 5 neighbours.
 */
public class KNN {
	// Training feature matrix; kept by fit() and scanned at prediction time.
	private double[][] X_train;
	// Training labels, aligned with X_train rows.
	private int[] y_train;
	// Number of neighbours that vote on each prediction.
	private int K = 5;

	/**
	 * Stores the training data. KNN is a lazy learner: no model is built here.
	 */
	public void fit(double[][] X,int[] y)
	{
		this.X_train = X;
		this.y_train = y;
	}

	/**
	 * Predicts a class label for every row of X.
	 */
	public int[] predict(double[][] X)
	{
		int[] result = new int[X.length];
		for(int i=0;i<X.length;i++)
		{
			result[i] = this.predict(X[i]);
		}
		return result;
	}

	/**
	 * Predicts the label of one sample by majority vote among its nearest
	 * neighbours (ties broken arbitrarily, as in the original).
	 */
	private int predict(double[] x)
	{
		int[] indices = this.getTopK(x);
		// Count how often each class label occurs among the neighbours.
		Map<Integer, Integer> votes = new HashMap<>();
		for(int id : indices)
		{
			votes.merge(this.y_train[id], 1, Integer::sum);
		}

		// Sort labels by descending vote count. Integer.compare avoids the
		// overflow-prone "b - a" subtraction comparator idiom.
		List<Map.Entry<Integer, Integer>> list = new ArrayList<>(votes.entrySet());
		list.sort((a,b)->Integer.compare(b.getValue(), a.getValue()));

		return list.get(0).getKey();
	}

	/**
	 * Returns the indices of the (up to) K training samples closest to x,
	 * ordered nearest-first, via insertion into a fixed-size sorted buffer.
	 * Fix: the original always allocated K slots even when the training set had
	 * fewer than K samples, so unfilled slots defaulted to index 0 and sample 0's
	 * label was over-counted in the vote.
	 */
	private int[] getTopK(double[] x) 
	{
		int k = Math.min(this.K, this.X_train.length);
		int[] indices = new int[k];
		double[] ds = new double[k];
		Arrays.fill(ds, Double.MAX_VALUE); // sentinel: "slot empty"

		for(int i=0;i<X_train.length;i++)
		{
			double d = getL2Distance(X_train[i], x);
			for(int j=0;j<k;j++)
			{
				if(d<ds[j])
				{
					// Shift worse entries down one slot, then insert at j.
					for(int s=k-1;s>j;s--)
					{
						ds[s]=ds[s-1];
						indices[s] = indices[s-1];
					}
					ds[j] = d;
					indices[j] = i;
					break;
				}
			}
		}
		return indices;
	}

	/**
	 * Euclidean (L2) distance between two equal-length vectors.
	 */
	private double getL2Distance(double[] a,double[] b)
	{
		double sum = 0;
		for(int i=0;i<a.length;i++)
		{
			sum += (a[i]-b[i])*(a[i]-b[i]);
		}
		return Math.sqrt(sum);
	}

	/**
	 * Classification accuracy on (X, y): fraction of correct predictions in [0, 1].
	 */
	public double score(double[][] X,int[] y)
	{
		int[] yhat = this.predict(X);
		int errorCount = 0;
		for(int i=0;i<yhat.length;i++)
		{
			if(yhat[i]!=y[i])
			{
				errorCount++;
			}
		}
		return (yhat.length - errorCount) / (double)yhat.length;
	}
}

四、代码运行:

注:总结了一下老师上课讲的KNN算法内容,希望对其他小伙伴有所帮助。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值