import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Minimal k-nearest-neighbour classifier over whitespace-separated sample files.
 *
 * <p>Each sample is a label followed by integer features; classification compares
 * a query vector against every training sample by Euclidean distance and lets
 * {@link TopK} vote among the {@code k} closest.
 */
class Knn {

    /** One labelled sample: a label token followed by its integer feature values. */
    static class Sample {
        String label;
        int[] pixels;
    }

    /**
     * Reads samples from a text file, one per line.
     *
     * <p>Each line is trimmed and split on runs of whitespace; the first token is the
     * label and the remaining tokens are parsed as integer features. Blank lines are
     * skipped: they would otherwise yield a sample with an empty label and no features.
     *
     * @param file path of the file to read
     * @return the samples in file order
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a feature token is not an integer
     */
    static List<Sample> readFile(String file) throws IOException {
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // Trim first: leading whitespace would make split() return an
                // empty first token, which would then be taken as the label.
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] tokens = line.split("\\s+");
                Sample sample = new Sample();
                sample.label = tokens[0];
                sample.pixels = new int[tokens.length - 1];
                for (int i = 1; i < tokens.length; i++) {
                    sample.pixels[i - 1] = Integer.parseInt(tokens[i]);
                }
                samples.add(sample);
            }
        } finally {
            reader.close();
        }
        return samples;
    }

    /**
     * Euclidean distance between two feature vectors of equal length.
     *
     * <p>The result is kept as a {@code double}; truncating it to an integer would map
     * distinct distances (e.g. 1.0 and 1.41) to the same value and scramble the
     * neighbour ranking. Squares are accumulated in a {@code long} to avoid overflow.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private static double distance(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "feature vectors differ in length: " + a.length + " vs " + b.length);
        }
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long diff = (long) a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Classifies a feature vector by voting among its {@code k} nearest training samples.
     *
     * @param trainingSet labelled samples to compare against
     * @param pixels the query feature vector
     * @param k number of neighbours that vote; must be at least 1
     * @return the label chosen by {@link TopK#getLabel()}; {@code null} when
     *         {@code trainingSet} is empty
     * @throws IllegalArgumentException if {@code k < 1} or a vector length differs
     */
    static String classify(List<Sample> trainingSet, int[] pixels, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        TopK nearest = new TopK(k);
        for (Sample sample : trainingSet) {
            double dist = distance(sample.pixels, pixels);
            // LabelDistance is an inner (non-static) class of TopK, hence the qualified new.
            nearest.add(nearest.new LabelDistance(sample.label, dist));
        }
        return nearest.getLabel();
    }
}
import java.util.HashMap;
import java.util.Set;
/**
 * Keeps the {@code k} closest labelled candidates seen so far and picks a label by
 * majority vote among them.
 */
public class TopK {
    /**
     * The best candidates so far, sorted by ascending distance. Slots not yet filled
     * hold sentinel entries with a {@code null} label and {@link Double#MAX_VALUE}.
     */
    public LabelDistance[] topk;
    /** Capacity of {@link #topk}. */
    int k;

    /** A label paired with its distance from the query. */
    class LabelDistance {
        String label;
        double distance;

        LabelDistance(String label, double distance) {
            this.label = label;
            this.distance = distance;
        }
    }

    /**
     * Creates an empty collector for the {@code k} nearest candidates.
     *
     * @param k number of candidates to keep; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    TopK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        this.k = k;
        topk = new LabelDistance[k];
        for (int i = 0; i < k; i++) {
            topk[i] = new LabelDistance(null, Double.MAX_VALUE);
        }
    }

    /**
     * Offers a candidate. It is kept only if it is strictly closer than the current
     * k-th best, in which case the farthest entry is evicted and the array stays sorted.
     * Among equal distances, earlier candidates stay ahead of later ones.
     *
     * @param candidate the labelled distance to consider
     */
    void add(LabelDistance candidate) {
        if (candidate.distance >= topk[k - 1].distance) {
            return;
        }
        // Insertion step: compare with the slot *before* i and shift it right while
        // it is farther, so the hole at i ends exactly where the candidate belongs.
        int i = k - 1;
        while (i > 0 && candidate.distance < topk[i - 1].distance) {
            topk[i] = topk[i - 1];
            i--;
        }
        topk[i] = candidate;
    }

    /**
     * Returns the most frequent label among the kept candidates.
     *
     * <p>Sentinel slots (null label) do not vote. Ties are broken in favour of the label
     * whose closest occurrence is nearest to the query.
     *
     * @return the winning label, or {@code null} if no real candidate was added
     */
    String getLabel() {
        HashMap<String, Integer> votes = new HashMap<String, Integer>();
        for (LabelDistance ld : topk) {
            if (ld.label == null) {
                continue;
            }
            Integer seen = votes.get(ld.label);
            votes.put(ld.label, seen == null ? 1 : seen + 1);
        }
        String best = null;
        int bestCount = 0;
        // Walk in ascending-distance order; the strict '>' keeps the first (nearest)
        // label among equally-voted ones, making the result deterministic.
        for (LabelDistance ld : topk) {
            if (ld.label == null) {
                continue;
            }
            int count = votes.get(ld.label);
            if (count > bestCount) {
                bestCount = count;
                best = ld.label;
            }
        }
        return best;
    }
}
import java.io.IOException;
import java.util.List;
public class testMain {
public static void main(String args[]) throws IOException {
List<Knn.Sample> trainingSet = Knn.readFile("letter.txt");
List<Knn.Sample> validationSet = Knn.readFile("sum.txt");
for (int j = 1; j < 21; j++) {
int numCorrect = 0;
for (Knn.Sample sample : validationSet) {
if (Knn.classify(trainingSet, sample.pixels, j).equals(
sample.label))
numCorrect++;
}
System.out.println("k = " + j + " Accuracy: "
+ (double) numCorrect / validationSet.size() * 100 + "%");
}
}
训练数据 (training samples, letter.txt): 19900
测试数据 (test samples, sum.txt): 100