// Java KNN implementation using Weka
import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.lazy.IBk;
/**
 * K-nearest-neighbour classification demo on a Weka dataset.
 *
 * <p>Loads an ARFF file, runs a hand-rolled KNN under k-fold cross-validation
 * ({@link #KFoldFitness}), and also provides a Weka {@code IBk} baseline
 * ({@link #KNNAcc}). Distances are squared Euclidean over all non-class
 * attributes; ranking by squared distance is equivalent to ranking by
 * Euclidean distance, so the sqrt is skipped.
 */
public class KNN {

    /** Number of cross-validation folds. */
    private static final int CV_FOLDS = 10;

    /** Number of neighbours consulted when voting. */
    private static final int NUM_NEIGHBOURS = 3;

    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("dataset/colon.arff");
        // The class label is assumed to be the last attribute of the ARFF file.
        data.setClassIndex(data.numAttributes() - 1);
        // No randomize() here: KFoldFitness shuffles the data itself, so the
        // original second shuffle in main was redundant.
        double acc = KFoldFitness(data);
        System.out.println(acc);
    }

    /**
     * Runs {@value #CV_FOLDS}-fold cross-validation with the hand-rolled KNN.
     *
     * @param train full dataset; shuffled in place before folding
     * @return mean per-fold accuracy in [0, 1]
     */
    private static double KFoldFitness(Instances train) {
        // Unseeded Random: results vary between runs. Pass a fixed seed here
        // if reproducible fitness values are required.
        train.randomize(new Random());
        double acc = 0;
        for (int fold = 0; fold < CV_FOLDS; fold++) {
            Instances trainData = train.trainCV(CV_FOLDS, fold);
            Instances testData = train.testCV(CV_FOLDS, fold);

            double[][] dist = getDistance(trainData, testData);
            int[][] knnIdx = getKNNIndex(dist, NUM_NEIGHBOURS);
            double[] predicted = getPredLabel(knnIdx, trainData);

            int correct = 0;
            for (int j = 0; j < testData.numInstances(); j++) {
                // classValue() returns the class index as a double, the same
                // encoding getPredLabel produces, so == comparison is safe.
                if (predicted[j] == testData.instance(j).classValue()) {
                    correct++;
                }
            }
            acc += (double) correct / testData.numInstances();
        }
        return acc / CV_FOLDS;
    }

    /**
     * Computes the squared Euclidean distance from every test instance to every
     * training instance over all attributes except the class attribute.
     *
     * @param train training fold
     * @param test  test fold
     * @return matrix of shape [numTest][numTrain] of squared distances
     */
    private static double[][] getDistance(Instances train, Instances test) {
        double[][] dist = new double[test.numInstances()][train.numInstances()];
        // Assumes the class attribute is the last one (set in main), hence the
        // "- 1" bound on the attribute loop.
        int numFeatures = test.numAttributes() - 1;
        for (int i = 0; i < test.numInstances(); i++) {
            // Hoist the test instance out of the inner loops; the original
            // re-fetched it for every attribute of every training instance.
            double[] testVals = test.instance(i).toDoubleArray();
            for (int j = 0; j < train.numInstances(); j++) {
                double[] trainVals = train.instance(j).toDoubleArray();
                double sumSq = 0;
                for (int k = 0; k < numFeatures; k++) {
                    double d = testVals[k] - trainVals[k];
                    sumSq += d * d;
                }
                dist[i][j] = sumSq;
            }
        }
        return dist;
    }

    /**
     * Selects, for each test instance, the indices of the {@code knnK} nearest
     * training instances (smallest distances; ties broken by lowest index).
     *
     * <p>Unlike the original version, the caller's distance matrix is left
     * untouched: selection works on a defensive per-row clone.
     *
     * @param dist distance matrix [numTest][numTrain]; not modified
     * @param knnK number of neighbours to select per test instance
     * @return matrix [numTest][knnK] of training-instance indices
     */
    private static int[][] getKNNIndex(double[][] dist, int knnK) {
        int[][] knnIdx = new int[dist.length][knnK];
        for (int i = 0; i < dist.length; i++) {
            // Clone so marking selected entries as +Inf does not clobber the
            // caller's matrix.
            double[] row = dist[i].clone();
            for (int j = 0; j < knnK; j++) {
                int best = 0;
                for (int k = 1; k < row.length; k++) {
                    if (row[k] < row[best]) {
                        best = k;
                    }
                }
                knnIdx[i][j] = best;
                // Exclude this neighbour from subsequent selection passes.
                row[best] = Double.POSITIVE_INFINITY;
            }
        }
        return knnIdx;
    }

    /**
     * Predicts a class index for each test instance by majority vote among its
     * nearest training neighbours (ties broken by lowest class index).
     *
     * <p>The original version also parsed every class label with
     * {@code Double.parseDouble} into an array that was never used — that would
     * throw for nominal labels such as "positive" and has been removed.
     *
     * @param knnIdx neighbour indices [numTest][k] into {@code train}
     * @param train  training fold supplying the neighbours' class values
     * @return predicted class index (as double) per test instance
     */
    private static double[] getPredLabel(int[][] knnIdx, Instances train) {
        double[] y = new double[knnIdx.length];
        int numClasses = train.numClasses();
        for (int i = 0; i < knnIdx.length; i++) {
            int[] votes = new int[numClasses];
            for (int neighbour : knnIdx[i]) {
                // classValue() is already the 0-based class index.
                votes[(int) train.instance(neighbour).classValue()]++;
            }
            int best = 0;
            for (int c = 1; c < numClasses; c++) {
                if (votes[c] > votes[best]) {
                    best = c;
                }
            }
            y[i] = best;
        }
        return y;
    }

    /**
     * Baseline accuracy using Weka's built-in {@code IBk} (k = 3) evaluated on
     * an explicit train/test split via a confusion matrix.
     *
     * @param train training set
     * @param test  test set
     * @return fraction of correctly classified test instances
     * @throws IllegalStateException if building or classifying fails, or the
     *         test set is empty (the original silently printed "Error" and
     *         then returned NaN from a 0/0 division)
     */
    private static double KNNAcc(Instances train, Instances test) {
        IBk cls = new IBk();
        cls.setKNN(NUM_NEIGHBOURS);
        double[][] confMatrix = new double[train.numClasses()][train.numClasses()];
        try {
            cls.buildClassifier(train);
            for (int i = 0; i < test.numInstances(); i++) {
                int pred = (int) cls.classifyInstance(test.instance(i));
                confMatrix[(int) test.instance(i).classValue()][pred]++;
            }
        } catch (Exception e) {
            // Preserve the cause instead of swallowing it and computing NaN.
            throw new IllegalStateException("IBk evaluation failed", e);
        }
        double correct = 0;
        int total = 0;
        for (int i = 0; i < confMatrix.length; i++) {
            correct += confMatrix[i][i];
            for (int j = 0; j < confMatrix[i].length; j++) {
                total += confMatrix[i][j];
            }
        }
        if (total == 0) {
            throw new IllegalStateException("Empty test set: accuracy undefined");
        }
        return correct / total;
    }
}