机器学习入门算法及其java实现-KNN算法

1、算法基本原理:

  • 对于一个新点 X0(x0,y0) ,它的分类 y0 由离它最近的k个点的类别决定;
  • 其中训练集为 T{(x1,y1),(x2,y2),...,(xn,yn)} ,离 X0(x0,y0) 最近的K个点根据分类决策规则(如多数表决)决定 X0(x0,y0) 的类别 y0 :
    y0=argmaxξjxjNk(x)I(yi=cj),i=1,2,...,N;j=1,2,...,K

    2、距离度量:
    特征空间中两个实例点的距离是两个实例点相似程度的反映。K近邻模型的特征空间一般是n维实数向量空间 Rn 。使用的距离是欧氏距离或其他距离,如更一般的 Lp 距离和Minkowski距离。
    设特征空间 χ 是n维实数向量空间 Rn , xi , xj χ , xi=(x(1)i,x(2)i,...,x(n)i)T , xj=(x(1)j,x(2)j,...,x(n)j)T , xi , xj Lp 的距离定义为:
    Lp(xi,xj)=(l=1n|xlixlj|p)1p
    这里 p1 ,当 p=2 时,称为欧式距离,即
    L2(xi,xj)=(l=1n|xlixlj|2)
    p=1 ,称为曼哈顿距离,即:
    L1(xi,xj)=l=1n|xlixlj|2

    p= 时,它时各个坐标距离的最大值,即
    L(xi,xj)=maxl|xlixli|

    3、K值的选择:
    K值的选择会对K近邻法产生重大的影响。
    如果选择较小的K值,相当于用较小的领域中的训练实例进行预测,“学习”的近似误差会减小,但学习的估计误差会增大,预测结果会对近邻的实例点非常敏感。
    如果选择较大的K值,可以减少估计误差,但是学习的近似误差会增大。
    在应用中,一般取一个较小的值,采用交叉验证的办法取最优的k值。
    4、分类决策规则:
    K近邻法中分类决策往往是多数表决,即由输入实例的K个近邻的训练实例中的多数类决定输入实例的类。
    使用平台:eclipse,R
    实验数据:人工数据
    相关程序:
    使用R生成数据,使用Java处理数据:
X<-matrix(1:50,nrow=25,ncol=2)
Y<-matrix(0,nrow=25,ncol=1)
for (i in 1:25){
   if (runif(1)<0.5){
   X[i,1]=exp(runif(1))*1.3  
   X[i,2]=exp(runif(1))*3
   Y[i]=0
}
   else{
   X[i,1]=exp(runif(1))*3
   X[i,2]=exp(runif(1))*3
   Y[i]=1
}
}
data<-cbind(Y,X)
write.table(data,"C:/Users/CJH/Desktop/R程序运行/KNNtest.txt",row.names=FALSEcol.names
=,FALSE)
#数据生成

KNNtrain<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNtrain.txt",head=FALSE)
KNNtest<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNtest.txt",head=FALSE)
class1<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer1.txt",head=FALSE)
class2<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer2.txt",head=FALSE)
class3<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer3.txt",head=FALSE)
class4<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer4.txt",head=FALSE)
class5<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer5.txt",head=FALSE)
class6<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer6.txt",head=FALSE)
class10<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer10.txt",head=FALSE)
class30<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer30.txt",head=FALSE)
#输入数据及结果

KNNtrain<-data.frame(KNNtrain)
names(KNNtrain)<-c("class","x","y")
KNNtrain$class<-factor(KNNtrain$class)
library(ggplot2)
ggplot(data=KNNtrain,aes(x=KNNtrain$x,y=KNNtrain$y,shape=KNNtrain$class,color=KNNtrain$class))+
geom_point(size=3)+labs(title="TrainData",x="x",y="y")

KNNtest<-data.frame(KNNtest)
names(KNNtest)<-c("class","x","y")
KNNtest$class<-factor(KNNtest$class)
ggplot(data=KNNtest,aes(x=KNNtest$x,y=KNNtest$y,shape=KNNtest$class,color=KNNtest$class))+
geom_point(size=3)+labs(title="TestData",x="x",y="y")
#生成图像

Eclipse程序:
package KNN;

import java.io.*;
import java.util.*;

public class InputData{
    public void loadData(double [][]x,double[]y,String trainfile)throws IOException{
       File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
       RandomAccessFile raf= new RandomAccessFile(file,"r");
       StringTokenizer tokenizer;   
       int i=0,j=0;
       while(true){
           String line = raf.readLine();
           if(line==null)break;
           tokenizer= new StringTokenizer(line);
           y[i]=Double.parseDouble(tokenizer.nextToken());
           while(tokenizer.hasMoreTokens()){
           x[i][j]=Double.parseDouble(tokenizer.nextToken());
           j++;
           }
           j=0;i++;
       }
       raf.close();
    }

}
//输入数据

package KNN;

public class CrossValidation {
    private int k;
    private int n;
    private int m;
    private int n1;
    public int getK(){
        return k;
        }
    public int getN(){
        return n;
        }
    public int getM(){
        return m;
        }
    public int getN1(){
        return n1;
        }
    public void setK(int b,int a,int t,int p){
         k=b;
         n=a;
         m=t;
         n1=p;
        }
}
//原始参数设置

package KNN;

public class KNN {

public double[] y(double[][] X,double[] Y,double[][] newpoints,int k,double[] c){
    double[][] distance=new double[newpoints.length][X.length];
    double[] y=new double[newpoints.length];
    int[][] rank=new int[newpoints.length][k];
    for (int i=0;i<newpoints.length;i++){
        for(int j=0;j<X.length;j++){
        distance[i][j]=euclid(newpoints[i],X[j]);
       }
    }
    for(int i=0;i<newpoints.length;i++){
        rank[i]=Rank(distance[i],k);
        y[i]=Y[rank[i][0]];
        int t1=count(rank[i],Y,Y[rank[i][0]]);
        for(int j=0;j<c.length;j++){
            int t=count(rank[i],Y,c[j]);
            if(t>t1){
                y[i]=c[j];
            }
        }
    }
    return y;
}

private int count(int[] rank, double[] y, double d) {
    int count=0;
    for (int i=0;i<rank.length;i++){
        if(y[rank[i]]==d){
            count=count+1;
        }
    }
    return count;
}

private int[] Rank(double[] distance,int k) {
    int[] Rank=new int[k];
    int[] temp1= new int[distance.length];
    for(int i=0;i<distance.length;i++){
        temp1[i]=i;
    }
    for (int i=0;i<k;i++){
               double temp=distance[i];
               int temp2=temp1[i];
               for(int j=i+1;j<distance.length;j++){
                   if(temp>distance[j]){
                       temp2=temp1[j];
                       temp1[j]=temp1[i];
                       temp1[i]=temp2;
                       temp=distance[j];
                   }
               if(temp1[i]!=i){
                   distance[j]=distance[i];
                   distance[i]=temp;
               }
        }
    }
               for (int i=0;i<k;i++){
                   Rank[i]=temp1[i];
               }
    return Rank;
}

private double euclid(double[] ds, double[] ds2) {
    double distance=0;
    if (ds.length==ds2.length){
        for(int i=0;i<ds.length;i++){
            distance=distance+Math.pow(ds[i]-ds2[i], 2);
        }
        distance=Math.sqrt(distance);
    }
    else{
        distance=0;
    }
    return distance;
}
}
//KNN算法


package KNN;

import java.io.IOException;

public class KNNmain{
    public static void main(String[] args) throws IOException {
        CrossValidation kvalue=new CrossValidation();
        kvalue.setK(30,100,2,25);
        int k=kvalue.getK();
        int n=kvalue.getN();
        int m=kvalue.getM();
        int n1=kvalue.getN1();
        double[] y=new double[n1];
        double[] Y1=new double[n1];
        double[][] newpoints=new double[n1][m];
        InputData ori=new InputData();
        InputData op=new InputData();
        double[][] X=new double[n][m];
        double[] Y=new double[n];
        ori.loadData(X, Y, "KNNtrain.txt");
        op.loadData(newpoints,Y1,"KNNtest.txt");
        KNN Kdata=new KNN();
        double[] c=new double[2];
        c[0]=0;
        c[1]=1;
        y=Kdata.y(X,Y,newpoints,k,c);
        double temp=0;
        for (int i=0;i<n1;i++){
            if(y[i]!=Y1[i]){
                temp=temp+1;
            }
            System.out.println(y[i]+" ");
        }
        System.out.println(temp/(double)newpoints.length);
    }
}
//主程序,输出分类结果及交叉验证结果

仿真数据:
这里写图片描述

训练数据

这里写图片描述

测试数据
k值k=1k=2k=3k=4k=5k=6k=10k=30
训练集错误率0.040.040.60.40.560.240.240.30

由上表,该数据中K值取1或者2是最好的。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值