java实现K-最近邻算法

package KNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

public class TestKNN {

	/**
	 * @param args
	 */
	public static void main(String[] args) {
		TestKNN t = new TestKNN();
		String datafile = "训练数据1.txt";
		String testfile = "测试数据1.txt";
		try {
			List<List<Double>> datas = new ArrayList<List<Double>>();
			List<List<Double>> testDatas = new ArrayList<List<Double>>();
			t.read(datas, datafile);
			t.read(testDatas, testfile);
			KNN knn = new KNN();
			for (int i = 0; i < testDatas.size(); i++) {
				List<Double> test = testDatas.get(i);
				System.out.print("测试元组: ");
				for (int j = 0; j < test.size(); j++) {
					System.out.print(test.get(j) + " ");
				}
				System.out.print("类别为: ");
				System.out.println(Math.round(Float.parseFloat((knn.knn(datas,
						test, 4))))); // 返回最接近参数的 int
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

	}

	public void read(List<List<Double>> datas, String path) {
		try {
			BufferedReader br = new BufferedReader(new FileReader(
					new File(path)));
			String data = br.readLine();
			List<Double> l = null;
			while (data != null) {
				String t[] = data.split(" ");
				l = new ArrayList<Double>();
				for (int i = 0; i < t.length; i++) {
					l.add(Double.parseDouble(t[i]));
				}
				datas.add(l);
				data = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

}

package KNN;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * KNN算法主体类
 */
public class KNN {
	/**
	 * 设置优先级队列的比较函数,距离越大,优先级越高
	 */
	private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
		public int compare(KNNNode o1, KNNNode o2) {
			if (o1.getDistance() >= o2.getDistance()) {
				return 1;
			} else {
				return 0;
			}
		}
	};

	/**
	 * 获取K个不同的随机数
	 * 
	 * @param k
	 *            随机数的个数
	 * @param max
	 *            随机数最大的范围
	 * @return 生成的随机数数组
	 */
	public List<Integer> getRandKNum(int k, int max) {
		List<Integer> rand = new ArrayList<Integer>(k);
		for (int i = 0; i < k; i++) {
			int temp = (int) (Math.random() * max);
			if (!rand.contains(temp)) {
				rand.add(temp);
			} else {
				i--;
			}
		}
		return rand;
	}

	/**
	 * 计算测试元组与训练元组之前的距离
	 * 
	 * @param d1
	 *            测试元组
	 * @param d2
	 *            训练元组
	 * @return 距离值
	 */
	public double calDistance(List<Double> d1, List<Double> d2) {
		double distance = 0.00;
		for (int i = 0; i < d1.size(); i++) {
			distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
		}
		return Math.sqrt(distance);
	}

	/**
	 * 执行KNN算法,获取测试元组的类别
	 * 
	 * @param datas
	 *            训练数据集
	 * @param testData
	 *            测试元组
	 * @param k
	 *            设定的K值
	 * @return 测试元组的类别
	 */
	public String knn(List<List<Double>> datas, List<Double> testData, int k) {
		ArrayList<KNNNode> alKN = new ArrayList<KNNNode>();
		for (int i = 0; i < k; i++) {
			List<Double> currData = datas.get(i);
			String c = currData.get(currData.size() - 1).toString();
			KNNNode node = new KNNNode(i, calDistance(testData, currData), c);
			alKN.add(node);
		}
		for (int i = k; i < datas.size(); i++) {
			List<Double> t = datas.get(i);
			double distance = calDistance(testData, t);
			for (int j = 0; j < alKN.size(); j++) {
				if ((alKN.get(j).getDistance()) >= distance) {
					alKN.remove(j);
					alKN.add(new KNNNode(i, distance, t.get(t.size() - 1)
							.toString()));
					break;
				}
			}

		}

		return getMostClass(alKN);
	}

	/**
	 * 获取所得到的k个最近邻元组的多数类
	 * 
	 * @param pq
	 *            存储k个最近近邻元组的优先级队列
	 * @return 多数类的名称
	 */
	private String getMostClass(ArrayList<KNNNode> alKN) {
		Map<String, Integer> classCount = new HashMap<String, Integer>();
		for (int i = 0; i < alKN.size(); i++) {
			KNNNode node = alKN.remove(i);
			String c = node.getC();
			if (classCount.containsKey(c)) {
				classCount.put(c, classCount.get(c) + 1);
			} else {
				classCount.put(c, 1);
			}
		}
		int maxIndex = -1;
		int maxCount = 0;
		Object[] classes = classCount.keySet().toArray();
		for (int i = 0; i < classes.length; i++) {
			if (classCount.get(classes[i]) > maxCount) {
				maxIndex = i;
				maxCount = classCount.get(classes[i]);
			}
		}
		return classes[maxIndex].toString();
	}
}

package KNN;

/**
 * KNN结点类,用来存储最近邻的k个元组相关的信息
 */
public class KNNNode {
	private int index; // 元组标号
	private double distance; // 与测试元组的距离
	private String c; // 所属类别

	public KNNNode(int index, double distance, String c) {
		super();
		this.index = index;
		this.distance = distance;
		this.c = c;
	}

	public int getIndex() {
		return index;
	}

	public void setIndex(int index) {
		this.index = index;
	}

	public double getDistance() {
		return distance;
	}

	public void setDistance(double distance) {
		this.distance = distance;
	}

	public String getC() {
		return c;
	}

	public void setC(String c) {
		this.c = c;
	}

}

训练数据:
2.0 1
4.0 0
10.0 1
12.0 0
3.0 1
20.0 0
22.0 1
21.0 0
11.0 1
24.0 0

测试数据:
18.0


结果:

测试元组: 18.0 类别为: 0


参考:http://blog.csdn.net/luowen3405/article/details/6278764

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值