人工智能 - K近邻分类算法的java实现

K近邻分类算法的java实现 手动输入K值

算法介绍:

关于算法的介绍以及实验的要求在上一个博客已经介绍,这里不再赘述。链接:
人工智能 - K近邻分类算法的Python实现

代码:

package KnnYin;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Scanner;
class DisTy implements Comparable<DisTy>{
     Double dis;
     Integer type;

    public DisTy(Double dis, Integer type) {
        this.dis = dis;
        this.type = type;
    }

    @Override
    public int compareTo(DisTy o) {
        return this.dis.compareTo(o.dis);
    }

    @Override
    public String toString() {
        return "DisTy{" +
                "dis=" + dis +
                ", type=" + type +
                '}';
    }
}

public class Knn {
    public static void main(String[] args) {
        File test = new File("D:/iris-data-testing.txt");
        File tran = new File("D:/iris-data-training.txt");
       // int cnt = 0;
        System.out.println("输入k");
        Scanner scs = new Scanner(System.in);
        int ans = scs.nextInt();
        try {
            Scanner sct = new Scanner(test);
            //System.out.println(scn.nextDouble());
            int[] type = new int[31];
            int cnt = 0;
            for (int i = 0; i < 31; i++) {
                Double[] ts = new Double[4];  //花的前四个数据double
                for (int j = 0; j < 4; j++) {
                    ts[j] = sct.nextDouble();
                    //System.out.print(ts[j]+" ");
                }
                type[i] = sct.nextInt();//花的类型
                //System.out.print(type[i]+" ");
                DisTy[] name = new DisTy[120];

                try {
                    Scanner scn = new Scanner(tran);

                    for (int j = 0; j < 120; j++){
                        Double[] tn = new Double[4];
                        for (int k = 0; k < 4; k++) {
                            tn[k] = scn.nextDouble();

                        }
                        double sum = 0;

                        for (int k = 0; k < 4; k++) {
                            sum += Math.pow(tn[k]-ts[k],2);
                        }

                        int h = scn.nextInt();
                        name[j] = new DisTy(Math.sqrt(sum),h);
                    }

                    int count[] = new int[5];
                    for(int j = 1;j < 4;j++) {
                        count[j] = 0;
                    }
                    Arrays.sort(name);
                    for (int j = 0; j < ans; j++) {
                        count[name[j].type]++;
                    }
                    int bestCount = 0,bestType = -1;
                    for(int j = 1;j<4;j++) {
                        if(count[j]>bestCount) {
                            bestCount = count[j];
                            bestType = j;
                        }
                    }

                    if(bestType == type[i]) {
                        cnt = cnt + 1;
                    }

                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            double r = cnt/31.0;
            System.out.println("正确率为:"+r);

        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
}

结果:

在这里插入图片描述
想求出所有K值结果看上一个博客。

思路(自定义数组的一个属性对数组进行排序):

思路和上一个差不多也是读取每一个要test的数据,去和训练集中的每一个数据求欧拉距离,求出的每一个数据的距离和类型放进一个类里面。然后就是根据自定义数组的一个属性对数组进行排序,在这里记录一下,备忘:
首先对类实现Comparable<~>接口,重写compareTo方法:

class DisTy implements Comparable<DisTy>{
     Double dis;
     Integer type;

    public DisTy(Double dis, Integer type) { //构造器
        this.dis = dis;
        this.type = type;
    }
    @Override
    public int compareTo(DisTy o) {   //   重写comparTo方法
        return this.dis.compareTo(o.dis);
    }
  }
}

然后用对数组操作的方法:

Arrays.sort(name);

注意在给自定义数组赋值时,要在new的时候初始化传参

DisTy[] name = new DisTy[120];
...
name[j] = new DisTy(Math.sqrt(sum),h);

这样就完成了排序。

另外一种排序方法就是在sort里面重写比较器:

Arrays.sort(peoples,new Comparator() {
@Override
public int compare(People o1,People o2) {
if(o1.weight==o2.weight){
return o2.height-o1.height;
}else{
return o1.weight-o2.weight;
}
}

});

这样的话类就不用实现Comparable<~>接口。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值