目录
一、大体思路:
- 输入所有已知点
- 输入未知点
- 计算所有已知点到未知点的欧氏距离
- 根据距离对所有已知点排序
- 选出距离未知点最近的k个点
- 计算k个点所在分类出现的频率
- 选择频率最大的类别即为未知点的类别
二、思路操作:
三个类:数据类Data、算法类KNN、测试类KNNTest
Data:
public class Data {
    // Feature matrices: one {x1, x2,...} row per sample (filled in later).
    public static double[][] X_train = {{}, {}, {}, {}};
    public static double[][] X_test = {{}, {}, {}, {}};
    // Label vectors, parallel to the feature matrices.
    public static int[] y_train = {};
    public static int[] y_test = {};
}
KNNTest:
public class KNNTest {
    public static void main(String[] args) {
        KNN model = new KNN();
        model.fit(Data.X_train, Data.y_train);
        int[] yhat = model.predict(Data.X_test);
        // Each printed value is prediction minus truth: 0 means correct.
        for (int i = 0; i < yhat.length; i++) {
            System.out.print(yhat[i] - Data.y_test[i] + ",");
        }
        System.out.print(model.score(Data.X_test, Data.y_test));
    }
}
KNN:
package knn;
import java.util.Map;
import org.omg.CORBA.PRIVATE_MEMBER;
// Skeleton of the KNN classifier. NOTE: the original sketch nested the
// private helper methods inside predict() and closed the class brace before
// score(), which is not valid Java; this version is a compilable outline.
public class KNN {
    private double[][] X_train; // training features, stored by fit()
    private int[] y_train;      // training labels, parallel to X_train
    private int K;              // number of neighbors consulted per prediction

    public KNN() {}

    // Memorize the training set (KNN is a lazy learner: no work here).
    public void fit(double[][] X, int[] y) {}

    // Predict a label for every row of X.
    public int[] predict(double[][] X) {
        return new int[0];
    }

    // Indices of the K training points closest to x.
    private int[] getTopK(double[] x) {
        return new int[0];
    }

    // Euclidean (L2) distance between two equal-length vectors.
    private double getL2Distance(double[] a, double[] b) {
        return 0;
    }

    // label -> occurrence count among the selected neighbors.
    private Map<Integer, Integer> getTickCount(int[] indices) {
        return null;
    }

    // Fraction of rows in X whose predicted label matches y.
    public double score(double[][] X, int[] y) {
        return 0;
    }
}
三、代码实现:
Data:
// Synthetic two-feature binary-classification dataset (classes 0 and 1),
// pre-split into train and test sets, consumed by KNNTest/KNN.
// Presumably generated by a tool such as sklearn's make_moons — TODO confirm.
public class Data {
// Training features: one {x1, x2} row per sample (160 samples).
public static double[][] X_train =
{{ 1.60440089, -0.35116358},
{-0.56357758, 1.15774717},
{-0.69863061, 0.77620293},
{-0.82364866, 0.86822754},
{ 0.45857236, 0.5017019 },
{-1.00392102, 1.15207238},
{ 0.95709601, 0.30502143},
{-0.21911253, 0.49662864},
{ 0.35766276, 1.0872811 },
{-0.25253387, 0.18776697},
{-1.15442774, 0.40574243},
{ 0.11834786, 0.99156023},
{ 0.55688998, 0.26120887},
{-0.62523133, 0.03868576},
{-0.01141219, 0.30437635},
{ 0.30143461, 0.76398572},
{-1.02663961, 0.41265823},
{-0.97644617, 0.03612864},
{ 1.5215205 , -0.1258923 },
{ 2.07463826, 0.51253705},
{ 0.18911691, 0.64887424},
{ 1.61652485, -0.29469483},
{ 0.4630177 , 0.86392418},
{ 0.57753017, 0.64564015},
{ 0.13229471, -0.26032619},
{ 1.95416454, -0.12850579},
{ 0.33332321, 0.14006592},
{ 0.17295579, 0.60135526},
{ 0.16919147, -0.30895665},
{-0.20784982, 1.09043495},
{ 1.22821377, -0.50119159},
{-0.00353234, 0.21487064},
{ 2.04251223, -0.46074593},
{ 0.68841773, 1.15068264},
{ 0.3203754 , 0.87271389},
{-0.70416708, 0.30649274},
{ 1.79531658, -0.32235591},
{ 2.02773677, 0.25021451},
{-0.66342048, 0.85301441},
{ 0.81816111, 0.78952806},
{ 0.8496547 , -0.30507345},
{ 0.71060915, 0.80990546},
{-0.41467068, 0.92254691},
{-0.14739828, 0.22556193},
{ 0.45166927, 1.00185497},
{ 1.43861749, -0.15796611},
{ 0.6233932 , -0.69422694},
{-1.32109805, 0.51437657},
{ 0.97332087, -0.70530678},
{-0.24648981, 1.09136047},
{ 0.04167562, -0.07462092},
{ 0.56116634, -0.0330033 },
{ 1.30440877, -0.52950917},
{ 0.1240315 , 0.67672995},
{ 0.29809162, 0.82211432},
{ 1.12344625, -0.09556327},
{-0.56365899, 0.8918972 },
{ 0.27695668, 0.01210816},
{ 0.30966003, 1.16677531},
{-0.57716798, 0.2942259 },
{-0.94262451, 0.57258351},
{ 0.87081342, -0.4366643 },
{ 1.12820353, 0.4664342 },
{ 0.80862106, 0.28415372},
{-0.15878875, 0.25584465},
{ 0.62273618, -0.52804644},
{-0.27631969, 1.34161045},
{ 0.48281721, -0.43196099},
{ 0.93694537, 0.36597075},
{-0.22047436, 1.28343927},
{ 0.74346118, 0.46465633},
{-0.3802171 , 0.88414623},
{ 0.99716574, 0.35017425},
{-1.35462041, 0.28524762},
{ 0.84248307, 0.55728315},
{ 0.53148147, -0.27424077},
{ 0.64426474, -0.36209808},
{ 1.05840957, 0.51858443},
{ 1.16443301, 0.01495781},
{ 1.60172165, -0.37604995},
{ 0.78362442, -0.25844144},
{ 1.99150162, 0.02534858},
{ 1.24491552, -0.5137574 },
{ 0.3395913 , -0.02223857},
{-0.7204608 , 1.01733354},
{-0.46610015, 0.98764879},
{ 1.48069489, -0.3572808 },
{ 1.0628775 , 0.17231496},
{ 0.55631903, -0.70781481},
{-0.38454276, 0.50916381},
{ 0.88357043, -0.35868845},
{ 2.04970274, 0.66368306},
{ 1.89660871, 0.25413209},
{ 1.40426435, -0.93206382},
{-0.93313522, 0.73385959},
{ 1.26735927, -0.11813675},
{ 0.6975216 , -0.11832611},
{ 1.2278091 , -0.64785108},
{ 0.63918299, 0.96606739},
{-0.64018301, 0.75214137},
{ 1.68734249, -0.6872367 },
{ 1.51241178, 0.11081331},
{ 0.1023533 , 1.09326207},
{-0.72894228, 0.44179419},
{ 1.06739115, -0.38783511},
{ 0.99633397, 0.1731019 },
{-0.25531442, 0.83953933},
{ 1.4298847 , -0.21080222},
{ 0.13387667, 0.6944329 },
{ 2.13330659, 0.11200406},
{ 2.25536302, 0.02862685},
{ 1.98373996, -0.11222315},
{ 0.09048712, 0.0890939 },
{ 0.15931878, -0.02835184},
{-0.13270154, 1.26653562},
{-1.088752 , -0.39694315},
{ 0.99826403, -0.80979075},
{ 1.0717818 , -0.40141988},
{-0.06114159, -0.02921072},
{-0.50840939, 0.55259494},
{-0.93516752, 0.43520803},
{ 0.94193283, 0.63204507},
{-0.48367053, 0.43679813},
{ 1.65936346, -0.70351567},
{ 0.91722632, -0.29657499},
{ 0.80950246, 0.3505231 },
{ 0.06721632, -0.1649077 },
{ 1.06519327, -0.38867949},
{ 1.65755662, -0.63203157},
{ 0.98031957, -0.56811367},
{ 0.37268327, 1.01743002},
{-1.07873435, 0.36644163},
{ 1.02592325, 0.42143427},
{ 0.88446589, -0.47595401},
{ 1.6049806 , 0.13835516},
{-0.0557186 , 0.57286794},
{ 0.82262609, -0.02317445},
{ 1.773464 , -0.34102513},
{ 0.26375409, 0.91508367},
{-0.8970433 , 0.87690996},
{ 0.45046033, 1.09585861},
{ 0.23337057, 0.10750568},
{ 1.32734432, -0.48056888},
{ 0.73451679, 0.5346233 },
{ 0.34005355, 0.32486358},
{ 1.99317676, 0.48903983},
{ 2.08208438, 0.00715606},
{-0.64683386, 0.46971252},
{ 0.20331 , 0.19519454},
{ 1.33653621, -0.18005761},
{-1.0873102 , 0.78128608},
{ 1.9369961 , 0.63112161},
{-0.33698825, 0.89060661},
{ 0.87629641, 0.28951999},
{ 0.71787709, -0.09708361},
{ 0.29449971, -0.26078938},
{-1.13305483, 0.11109962},
{ 0.01262924, -0.59715374},
{-0.12213442, 0.15292037},
{ 0.87161309, 0.01715969}};
// Training labels, parallel to X_train (one label per row above).
public static int[] y_train=
{1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1,
0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1,
0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0,
1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1,
1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0,
1, 1, 0, 1, 1, 0};
// Held-out test features, same {x1, x2} layout as X_train (40 samples).
public static double[][] X_test =
{{ 0.1059936 , 0.90514125},
{ 0.63711933, -0.55537183},
{ 0.75877114, 0.76057045},
{ 0.17673342, 1.3178874 },
{ 1.92673019, 0.40817963},
{ 0.25919429, 1.04104213},
{ 1.77377397, -0.10513907},
{ 1.768052 , -0.25443213},
{ 0.03351023, 0.63113817},
{ 1.51658699, -0.45069719},
{-1.2855903 , 0.10677262},
{ 0.2482245 , 0.7860477 },
{ 0.09252135, -0.31618454},
{ 1.96682571, 0.2646737 },
{-0.59909677, 0.76903039},
{ 0.6515764 , -0.40677494},
{ 0.6779624 , 0.78024482},
{ 1.79414416, 0.28323389},
{ 1.01934807, -0.17993629},
{-0.21582418, 1.03521642},
{ 0.4828491 , -0.21452374},
{ 0.66388701, 0.94659669},
{ 1.53296943, -0.36277826},
{ 2.0436285 , 0.24563453},
{-0.39131894, 0.40925201},
{ 0.75307594, 0.8526869 },
{-0.75191922, 0.63798317},
{-0.03857867, 0.0838378 },
{ 0.88155818, 0.23925957},
{ 1.82749513, -0.03085446},
{-0.70574528, 0.54883003},
{ 1.45384401, 0.12718529},
{-0.07795147, 0.27995261},
{-0.70851168, 0.49617855},
{ 1.43380709, 0.69183071},
{ 0.66137686, -0.31314104},
{ 0.25403599, -0.00644002},
{-0.52953439, 0.69307316},
{-0.99805184, 0.62420544},
{-1.07840959, 0.56402523}};
// Ground-truth test labels, parallel to X_test.
public static int[] y_test = {0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0,
1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0};
}
KNNTest:
public class KNNTest {
    public static void main(String[] args) {
        KNN model = new KNN();
        model.fit(Data.X_train, Data.y_train);
        int[] predictions = model.predict(Data.X_test);
        // First line: per-sample error (prediction - truth, so 0 = correct).
        // Second line: overall accuracy on the test set.
        StringBuilder errors = new StringBuilder();
        for (int i = 0; i < predictions.length; i++) {
            errors.append(predictions[i] - Data.y_test[i]).append(",");
        }
        System.out.print(errors);
        System.out.println();
        System.out.print(model.score(Data.X_test, Data.y_test));
    }
}
KNN:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * K-nearest-neighbors classifier over double[] feature vectors with int labels.
 * A lazy learner: fit() only memorizes the training data; all work happens in
 * predict()/score().
 */
public class KNN {
    private double[][] X_train; // training features, one row per sample
    private int[] y_train;      // training labels, parallel to X_train
    private int K = 5;          // number of neighbors consulted per prediction

    /** Creates a classifier using the default of 5 neighbors. */
    public KNN() {}

    /**
     * Creates a classifier consulting {@code k} nearest neighbors.
     *
     * @param k number of neighbors, must be at least 1
     * @throws IllegalArgumentException if k &lt; 1
     */
    public KNN(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        this.K = k;
    }

    /**
     * Memorizes the training set.
     *
     * @param X feature matrix, one row per sample
     * @param y labels, parallel to X
     * @throws IllegalArgumentException if X and y have different lengths
     */
    public void fit(double[][] X, int[] y) {
        if (X.length != y.length) {
            throw new IllegalArgumentException(
                "X and y lengths differ: " + X.length + " vs " + y.length);
        }
        this.X_train = X;
        this.y_train = y;
    }

    /** Predicts a label for every row of X. */
    public int[] predict(double[][] X) {
        int[] result = new int[X.length];
        for (int i = 0; i < X.length; i++) {
            result[i] = predict(X[i]);
        }
        return result;
    }

    /** Predicts one sample's label by majority vote among its K nearest neighbors. */
    private int predict(double[] x) {
        Map<Integer, Integer> votes = new HashMap<>();
        for (int id : getTopK(x)) {
            votes.merge(this.y_train[id], 1, Integer::sum);
        }
        List<Map.Entry<Integer, Integer>> entries = new ArrayList<>(votes.entrySet());
        // Integer.compare avoids the overflow pitfall of subtracting values.
        entries.sort((a, b) -> Integer.compare(b.getValue(), a.getValue()));
        return entries.get(0).getKey();
    }

    /**
     * Returns the indices of the nearest training points, closest first,
     * via insertion into a small sorted buffer (O(n*K)).
     */
    private int[] getTopK(double[] x) {
        // Cap at the training-set size so we never report phantom neighbors
        // (the old code silently returned index 0 for the unfilled slots).
        int k = Math.min(this.K, this.X_train.length);
        int[] indices = new int[k];
        double[] ds = new double[k];
        Arrays.fill(ds, Double.MAX_VALUE);
        for (int i = 0; i < this.X_train.length; i++) {
            double d = getL2Distance(this.X_train[i], x);
            for (int j = 0; j < k; j++) {
                if (d < ds[j]) {
                    // Shift worse entries right to open slot j.
                    for (int m = k - 1; m > j; m--) {
                        ds[m] = ds[m - 1];
                        indices[m] = indices[m - 1];
                    }
                    ds[j] = d;
                    indices[j] = i;
                    break;
                }
            }
        }
        return indices;
    }

    /** Euclidean (L2) distance between two equal-length vectors. */
    private double getL2Distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Accuracy: the fraction of rows in X whose predicted label equals y.
     *
     * @return value in [0, 1]
     */
    public double score(double[][] X, int[] y) {
        int[] yhat = predict(X);
        int correct = 0;
        for (int i = 0; i < yhat.length; i++) {
            if (yhat[i] == y[i]) {
                correct++;
            }
        }
        return correct / (double) yhat.length;
    }
}
四、代码运行:
注:总结了一下老师上课讲的KNN算法内容,希望对其他小伙伴有所帮助。