/**
 * Command-line driver for the {@link KNN} classifier.
 *
 * <p>Reads labelled rows from {@code data.txt} and rows to classify from {@code test.txt},
 * then prints every test row followed by the rounded category predicted for it.
 */
public class KnnTest {

    /**
     * Parses a text file of whitespace-separated numbers and appends one row per
     * non-blank line to {@code list}.
     *
     * <p>Tokens may be separated by any run of spaces or tabs, and leading or trailing
     * whitespace on a line is ignored. An {@link IOException} is reported on standard
     * error and the rows read so far are kept, so the caller always gets a best-effort
     * result. The reader is closed in every case.
     *
     * @param path file to read
     * @param list destination that receives one {@code List<Double>} per non-blank line
     * @throws NumberFormatException if a token is not a valid {@code double}
     */
    public static void readFileToList(String path, List<List<Double>> list) {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(path));
            String line;
            // readLine() returning null is the reliable end-of-stream signal; ready() only
            // says whether a read would block.
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                // Split on whitespace runs so repeated spaces or tabs never yield empty tokens.
                String[] tokens = trimmed.split("\\s+");
                List<Double> box = new ArrayList<Double>(tokens.length);
                for (String num : tokens) {
                    box.add(Double.parseDouble(num));
                }
                list.add(box);
            }
        } catch (IOException ex) {
            ex.printStackTrace();
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
        }
    }

    /**
     * Classifies every row of {@code test.txt} against {@code data.txt} using the 2 nearest
     * neighbours and prints the row followed by the predicted category rounded to an integer.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int k = 2;
        String dataFile = "data.txt";
        String testFile = "test.txt";
        KNN knn = new KNN();
        try {
            List<List<Double>> dataList = new ArrayList<List<Double>>();
            List<List<Double>> testList = new ArrayList<List<Double>>();
            readFileToList(dataFile, dataList);
            readFileToList(testFile, testList);
            for (List<Double> test : testList) {
                for (Double d : test) {
                    System.out.print(d + " ");
                }
                String category = knn.knn(dataList, test, k);
                // The category is the text form of a double (e.g. "1.0"); print it as an integer.
                System.out.println(Math.round(Float.parseFloat(category)));
            }
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}

/**
 * k-nearest-neighbour classifier over rows of doubles.
 *
 * <p>Each example row holds its feature values followed by the category as the last element.
 */
class KNN {

    /**
     * Orders nodes by descending distance, so the head of a {@link PriorityQueue} built
     * with it is the farthest neighbour currently kept.
     */
    private static final Comparator<Node> FARTHEST_FIRST = new Comparator<Node>() {
        @Override
        public int compare(Node n1, Node n2) {
            return Double.compare(n2.getDistans(), n1.getDistans());
        }
    };

    /**
     * Predicts the category of {@code test} by majority vote among its {@code k} nearest
     * examples, measured with {@link #calDistans(List, List)}.
     *
     * <p>If {@code k} exceeds the number of examples, every example votes. When several
     * categories share the highest vote count, the one reached first while walking the
     * neighbours from nearest to farthest wins, so the result is deterministic.
     *
     * @param example labelled rows; the last element of each row is its category
     * @param test    feature values of the row to classify
     * @param k       number of neighbours that vote; must be at least 1
     * @return the winning category, formatted with {@link Double#toString()}
     * @throws IllegalArgumentException if {@code k < 1} or {@code example} is null or empty
     */
    public String knn(List<List<Double>> example, List<Double> test, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (example == null || example.isEmpty()) {
            throw new IllegalArgumentException("example must contain at least one row");
        }
        int capacity = Math.min(k, example.size());
        PriorityQueue<Node> pq = new PriorityQueue<Node>(capacity, FARTHEST_FIRST);
        for (int i = 0; i < example.size(); i++) {
            List<Double> row = example.get(i);
            double distans = calDistans(test, row);
            if (pq.size() < capacity) {
                pq.add(new Node(i, distans, categoryOf(row)));
            } else if (pq.peek().getDistans() > distans) {
                // Strictly closer than the farthest kept neighbour: swap it in. Each row is
                // visited exactly once, so no example can be counted twice.
                pq.remove();
                pq.add(new Node(i, distans, categoryOf(row)));
            }
        }
        return getMostCategory(pq);
    }

    /** Returns the category of an example row, i.e. the text form of its last element. */
    private static String categoryOf(List<Double> row) {
        return row.get(row.size() - 1).toString();
    }

    /**
     * Drains {@code pq} and returns the category with the most votes.
     *
     * @param pq neighbours ordered farthest-first; emptied by this call
     * @return the most frequent category, or {@code null} if {@code pq} was empty
     */
    private String getMostCategory(PriorityQueue<Node> pq) {
        // Draining yields farthest-first order; the list is walked backwards below.
        List<Node> neighbours = new ArrayList<Node>(pq.size());
        while (!pq.isEmpty()) {
            neighbours.add(pq.remove());
        }

        Map<String, Integer> votes = new HashMap<String, Integer>();
        for (Node node : neighbours) {
            Integer seen = votes.get(node.getCategory());
            votes.put(node.getCategory(), seen == null ? 1 : seen + 1);
        }

        String best = null;
        int bestVotes = 0;
        // Nearest neighbour first; the strict '>' lets the earlier (closer) category keep a tie.
        for (int i = neighbours.size() - 1; i >= 0; i--) {
            String category = neighbours.get(i).getCategory();
            int count = votes.get(category);
            if (count > bestVotes) {
                best = category;
                bestVotes = count;
            }
        }
        return best;
    }

    /**
     * Computes the squared Euclidean distance between the first {@code list1.size()}
     * components of the two lists. Extra trailing elements of {@code list2} (such as a
     * category value) are ignored.
     *
     * @param list1 vector whose length determines how many components are compared
     * @param list2 vector with at least {@code list1.size()} elements
     * @return sum of squared component differences
     * @throws IndexOutOfBoundsException if {@code list2} is shorter than {@code list1}
     */
    public double calDistans(List<Double> list1, List<Double> list2) {
        double result = 0.00;
        for (int i = 0; i < list1.size(); i++) {
            double diff = list1.get(i) - list2.get(i);
            result += diff * diff;
        }
        return result;
    }

    /** A candidate neighbour: the example's index, its distance to the test row, and its category. */
    static class Node {
        private int index;
        private double distans;
        private String category;

        public Node(int index, double distans, String category) {
            this.index = index;
            this.distans = distans;
            this.category = category;
        }

        public int getIndex() {
            return index;
        }

        public void setIndex(int index) {
            this.index = index;
        }

        public double getDistans() {
            return distans;
        }

        public void setDistans(double distans) {
            this.distans = distans;
        }

        public String getCategory() {
            return category;
        }

        public void setCategory(String category) {
            this.category = category;
        }
    }
}
// Reposted from: https://www.cnblogs.com/rilley/archive/2012/09/18/2690098.html