KNN算法实现以及性能测试

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

class Knn {
    static class Sample {
        String label;
        int[] pixels;
    }

    static List<Sample> readFile(String file) throws IOException {
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.split("\\s+");
                Sample sample = new Sample();
                sample.label = tokens[0];
                sample.pixels = new int[tokens.length - 1];
                for (int i = 1; i < tokens.length; i++) {
                    sample.pixels[i - 1] = Integer.parseInt(tokens[i]);
                }
                samples.add(sample);
            }
        } finally {
            reader.close();
        }
        return samples;
    }

    private static int distance(int[] a, int[] b) {
        int sum = 0;

        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int) Math.sqrt(sum);
    }

    static String classify(List<Sample> trainingSet, int[] pixels, int k) {

        TopK TK = new TopK(k);
        for (Sample sample : trainingSet) {
            double dist = distance(sample.pixels, pixels);
            TK.add(TK.new LabelDistance(sample.label, dist));
        }
        return TK.getLabel();
    }
}


import java.util.HashMap;
import java.util.Set;

public class TopK {

    public LabelDistance[] topk;
    int k;

    class LabelDistance {
        String label;
        double distance;

        LabelDistance(String label, double distance) {
            this.label = label;
            this.distance = distance;
        }
    }

    TopK(int k) {
        this.k = k;
        topk = new LabelDistance[k];
        for (int i = 0; i < k; i++) {
            topk[i] = new LabelDistance(null, Double.MAX_VALUE);
        }

    }

    void add(LabelDistance LD) {
        if (LD.distance < topk[k - 1].distance) {

            int i;
            for (i = k - 1; LD.distance < topk[i].distance && i >= 1; i--) {
                topk[i] = topk[i - 1];
            }

            topk[i] = LD;
        }
    }

    String getLabel() {
        String label = null;
        HashMap<String, Integer> map = new HashMap<String, Integer>();
        for (LabelDistance ld : topk) {
            if (map.containsKey(ld.label))
                map.put(ld.label, map.get(ld.label) + 1);
            else
                map.put(ld.label, 1);
        }
        int count = Integer.MIN_VALUE;

        Set<String> set = map.keySet();
        for (String s : set) {
            if (map.get(s) > count) {
                count = map.get(s);
                label = s;
            }
        }

        return label;
    }
}



import java.io.IOException;
import java.util.List;

public class testMain {

    public static void main(String args[]) throws IOException {
        List<Knn.Sample> trainingSet = Knn.readFile("letter.txt");
        List<Knn.Sample> validationSet = Knn.readFile("sum.txt");
        for (int j = 1; j < 21; j++) {
            int numCorrect = 0;
            for (Knn.Sample sample : validationSet) {
                if (Knn.classify(trainingSet, sample.pixels, j).equals(
                        sample.label))
                    numCorrect++;
            }
            System.out.println("k = " + j + "   Accuracy: "
                    + (double) numCorrect / validationSet.size() * 100 + "%");
        }
    }






训练数据: 19900

测试数据:100


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值