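HUNST Data Mining Course Project: Experiment 5, KNN Algorithm Design and Application
KNN Algorithm Design and Application
I. Experiment Content
1. Experiment Requirements
2. Experiment Principle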
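K-nearest neighbors (KNN) is a widely used supervised learning algorithm for both classification and regression. Given a test sample, KNN uses a distance metric to measure how similar it is to each sample in the data set. Common metrics include Euclidean distance, Manhattan distance, and Minkowski distance; the smaller the distance, the more similar the two samples. The algorithm then takes the k nearest samples and, for classification, predicts the category that occurs most often among them.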
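The three metrics named above can be sketched in a few lines of Java. This is an illustration only: the class and method names (DistanceMetrics, minkowski, manhattan, euclidean) are assumptions and are not part of the experiment's program. With a single feature such as height, all three reduce to the absolute difference, which is the quantity the program in this report uses.

```java
// Illustrative sketch; names are hypothetical, not taken from the original program.
final class DistanceMetrics {

    // Minkowski distance of order p (p >= 1) between two equal-length vectors.
    static double minkowski(double[] a, double[] b, double p) {
        checkSameLength(a, b);
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.pow(Math.abs(a[i] - b[i]), p);
        }
        return Math.pow(sum, 1.0 / p);
    }

    // Manhattan distance: the sum of absolute coordinate differences (Minkowski with p = 1).
    static double manhattan(double[] a, double[] b) {
        checkSameLength(a, b);
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    // Euclidean distance: the square root of the summed squared differences (Minkowski with p = 2).
    static double euclidean(double[] a, double[] b) {
        checkSameLength(a, b);
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    private static void checkSameLength(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
    }

    public static void main(String[] args) {
        double[] x = {0.0, 0.0};
        double[] y = {3.0, 4.0};
        System.out.println(euclidean(x, y)); // prints 5.0
        System.out.println(manhattan(x, y)); // prints 7.0
    }
}
```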
3. Program Flowchart
II. Code
1. Implementation Approach
A Person class is defined with the following fields:
int id;
String name;
double height;
String category;
double heightDifference;
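File reading: the required data is read from data.txt (the complete code below passes the path ../5-KNN/dataset.txt). The reading method is as follows: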
private static List<Person> readData(String fileName) throws Exception {
    List<Person> person = new ArrayList<>();
    // try-with-resources closes the reader even if parsing fails
    try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
        String line = br.readLine(); // read and discard the header row
        while ((line = br.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue; // skip blank lines
            }
            // split the record on runs of whitespace
            String[] parts = line.split("\\s+");
            int id = Integer.parseInt(parts[0]); // record id
            String name = parts[1];
            double height = Double.parseDouble(parts[2]);
            String category = parts[3];
            // heightDifference starts at 0 and is filled in during classification
            Person pers = new Person(id, name, height, category, 0);
            person.add(pers);
        }
    }
    return person;
}
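As a small illustration of why trimming matters when splitting on whitespace: String.split keeps an empty leading element when the line starts with a separator, so an indented record would shift every field by one position. The sketch below uses a made-up row and a hypothetical helper named fields; the real layout of the data file is not shown in this report.

```java
import java.util.Arrays;

// Illustrative sketch; the class, the helper, and the sample row are hypothetical.
final class LineParsingSketch {

    // Splits one record on runs of whitespace after trimming both ends.
    static String[] fields(String line) {
        return line.trim().split("\\s+");
    }

    public static void main(String[] args) {
        String row = "  1   Alice  1.60   short"; // made-up row with leading spaces
        // Without trim(), the first element is an empty string.
        System.out.println(Arrays.toString(row.split("\\s+"))); // prints [, 1, Alice, 1.60, short]
        // With trim(), the four fields line up as expected.
        System.out.println(Arrays.toString(fields(row)));        // prints [1, Alice, 1.60, short]
    }
}
```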
1. Compute the height difference between the test record (unknownPerson) and every tuple in the data set, and store it in the heightDifference field.
2. Based on these height differences, select the k tuples closest to the test record and store them in the nearestNeighbors list.
List<Person> nearestNeighbors = new ArrayList<>();
for (int i = 0; i < person.size(); i++) {
    Person pers = person.get(i);
    double difference = pers.height - unknownPerson.height;
    // round to two decimal places; Math.round avoids the locale-dependent
    // decimal separator that DecimalFormat.format can produce
    pers.heightDifference = Math.round(Math.abs(difference) * 100.0) / 100.0;
    System.out.println((i + 1) + " " + pers.heightDifference);
}
System.out.println();
// seed nearestNeighbors with the first k people; later records may replace them
for (int i = 0; i < k; i++) {
    Person pers = person.get(i);
    nearestNeighbors.add(pers);
}
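The height difference of each remaining element is then compared with those of the current k nearest neighbors.
If it is smaller than the largest difference among them (that is, the current farthest neighbor), the farthest neighbor is replaced by the current element, so the list always holds the k nearest records seen so far.
for (int i = k; i < person.size(); i++) {
    Person pers = person.get(i);
    // locate the current farthest neighbor among the k kept so far
    double maxDifference = 0;
    int maxIndex = 0;
    for (int j = 0; j < k; j++) {
        if (nearestNeighbors.get(j).heightDifference > maxDifference) {
            maxDifference = nearestNeighbors.get(j).heightDifference;
            maxIndex = j;
        }
    }
    // a closer record takes the place of the farthest neighbor
    if (pers.heightDifference < maxDifference) {
        nearestNeighbors.set(maxIndex, pers);
    }
}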
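The replacement loop above rescans all k neighbors for every remaining record. As a sketch of an alternative, assuming the distances are already available as a plain array, a bounded max-heap keeps the farthest of the k kept neighbors at its head, so each record needs only one comparison against that head. The names NearestBySelection and kNearest are hypothetical and do not appear in the original program, and the method returns indices rather than Person objects.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Illustrative sketch; names are hypothetical, not taken from the original program.
final class NearestBySelection {

    // Returns the indices of the k smallest distances, nearest first.
    static List<Integer> kNearest(double[] distances, int k) {
        if (k < 1 || k > distances.length) {
            throw new IllegalArgumentException("k must be between 1 and " + distances.length);
        }
        // Max-heap on distance: the head is always the farthest neighbor kept so far.
        PriorityQueue<Integer> heap = new PriorityQueue<>(
                k, Comparator.comparingDouble((Integer i) -> distances[i]).reversed());
        for (int i = 0; i < distances.length; i++) {
            if (heap.size() < k) {
                heap.add(i);            // still filling the first k slots
            } else if (distances[i] < distances[heap.peek()]) {
                heap.poll();            // drop the current farthest neighbor
                heap.add(i);            // keep the closer record instead
            }
        }
        List<Integer> result = new ArrayList<>(heap);
        result.sort(Comparator.comparingDouble((Integer i) -> distances[i]));
        return result;
    }

    public static void main(String[] args) {
        double[] distances = {0.3, 0.1, 0.5, 0.05, 0.2}; // made-up height differences
        System.out.println(kNearest(distances, 3)); // prints [3, 1, 4]
    }
}
```

Like the original loop, this sketch uses a strict comparison, so a record that ties with the current farthest neighbor does not replace it.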
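Majority voting: once the k nearest neighbors have been found, the algorithm counts how many of them belong to each category and returns the category that occurs most often as the prediction.
// categoryCounts maps each category to its number of occurrences among the k neighbors
// find the category that occurs most often
String predictedCategory = "";
int maxCount = 0;
for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
    String category = entry.getKey();
    int count = entry.getValue();
    if (count > maxCount) {
        predictedCategory = category;
        maxCount = count;
    }
}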
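The counting and voting steps can be tried in isolation with the sketch below. The class name MajorityVoteSketch, the method vote, and the sample labels are assumptions for illustration. The sketch uses a LinkedHashMap so that, on a tie, the label that appears first in the input wins; with a HashMap the winner of a tie depends on the map's iteration order.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch; names and sample labels are hypothetical.
final class MajorityVoteSketch {

    // Returns the most frequent label, or null for an empty list.
    // On a tie, the label that appears first in the input wins.
    static String vote(List<String> labels) {
        // LinkedHashMap iterates in insertion order, which makes ties deterministic.
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String label : labels) {
            counts.merge(label, 1, Integer::sum);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(vote(Arrays.asList("tall", "medium", "medium"))); // prints medium
    }
}
```

Choosing an odd k reduces the chance of a tie when there are two categories, although ties remain possible with three or more.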
2. Complete Code
- Person.java
/**
 * Created by 23222 on 2023/12/12.
 */
public class Person {
    int id;
    String name;
    double height;
    String category;
    double heightDifference;

    public Person(int id, String name, double height, String category, double heightDifference) {
        this.id = id;
        this.name = name;
        this.height = height;
        this.category = category;
        this.heightDifference = heightDifference;
    }

    @Override
    public String toString() {
        return "Person{" +
                "id=" + id +
                ", name='" + name + '\'' +
                ", height=" + height +
                ", category='" + category + '\'' +
                ", heightDifference=" + heightDifference +
                '}';
    }
}
- kNN.java
import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.*;
/**
 * Created by 23222 on 2023/12/12.
 */
public class kNN {
    public static void main(String[] args) throws Exception {
        // load the data set
        List<Person> dataset = readData("../5-KNN/dataset.txt");
        Scanner scanner = new Scanner(System.in);
        int id = dataset.size();
        System.out.println("Enter the name:");
        String name = scanner.next(); // reads a single whitespace-delimited token
        System.out.println("Enter the height:");
        double height = scanner.nextDouble();
        System.out.println("Enter the K:");
        int k = scanner.nextInt();
        // the category of the test record is unknown until classification
        Person unknownPerson = new Person(id, name, height, null, 0);
        String predictedCategory = classify(k, unknownPerson, dataset);
        unknownPerson.category = predictedCategory;
        System.out.println("K=" + k);
        System.out.println("Prediction: " + unknownPerson.name + " " + unknownPerson.height + "----->" + unknownPerson.category);
    }
    private static String classify(int k, Person unknownPerson, List<Person> person) {
        // k must not exceed the data set size, otherwise the seeding loop below fails
        if (k < 1 || k > person.size()) {
            throw new IllegalArgumentException("k must be between 1 and " + person.size());
        }
        // compute the height difference between the test record and every tuple in the data set
        List<Person> nearestNeighbors = new ArrayList<>();
        for (int i = 0; i < person.size(); i++) {
            Person pers = person.get(i);
            double difference = pers.height - unknownPerson.height;
            // round to two decimal places; Math.round avoids the locale-dependent
            // decimal separator that DecimalFormat.format can produce
            pers.heightDifference = Math.round(Math.abs(difference) * 100.0) / 100.0;
            System.out.println((i + 1) + " " + pers.heightDifference);
        }
        System.out.println();
        // seed the neighbor list with the first k records
        for (int i = 0; i < k; i++) {
            Person pers = person.get(i);
            nearestNeighbors.add(pers);
        }
        // for each remaining record, replace the current farthest neighbor if the record is closer
        for (int i = k; i < person.size(); i++) {
            Person pers = person.get(i);
            double maxDifference = 0;
            int maxIndex = 0;
            for (int j = 0; j < k; j++) {
                if (nearestNeighbors.get(j).heightDifference > maxDifference) {
                    maxDifference = nearestNeighbors.get(j).heightDifference;
                    maxIndex = j;
                }
            }
            if (pers.heightDifference < maxDifference) {
                nearestNeighbors.set(maxIndex, pers);
            }
        }
        // print the k nearest neighbors
        for (Person neighbor : nearestNeighbors) {
            System.out.println(neighbor);
        }
        // count how often each category (tall, short, medium) occurs among the k tuples
        Map<String, Integer> categoryCounts = new HashMap<>();
        for (Person neighbor : nearestNeighbors) {
            String category = neighbor.category;
            categoryCounts.put(category, categoryCounts.getOrDefault(category, 0) + 1);
        }
        // print the per-category counts
        for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
            System.out.println(entry.getKey() + " : " + entry.getValue());
        }
        // find the category that occurs most often
        String predictedCategory = "";
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
            String category = entry.getKey();
            int count = entry.getValue();
            if (count > maxCount) {
                predictedCategory = category;
                maxCount = count;
            }
        }
        return predictedCategory;
    }
    private static List<Person> readData(String fileName) throws Exception {
        List<Person> person = new ArrayList<>();
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String line = br.readLine(); // read and discard the header row
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // skip blank lines
                }
                // split the record on runs of whitespace
                String[] parts = line.split("\\s+");
                int id = Integer.parseInt(parts[0]); // record id
                String name = parts[1];
                double height = Double.parseDouble(parts[2]);
                String category = parts[3];
                // heightDifference starts at 0 and is filled in during classification
                Person pers = new Person(id, name, height, category, 0);
                person.add(pers);
            }
        }
        return person;
    }
}