HUNST Data Mining Course Project: "Lab 5: KNN Algorithm Design and Application"

KNN Algorithm Design and Application

I. Experiment Content

1. Experiment Requirements

[Figure: experiment requirements]

2. Experimental Principle

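The k-nearest neighbors (KNN) algorithm is a widely used supervised learning algorithm that can be applied to both classification and regression. For a given test sample, KNN uses a distance metric to measure how similar two samples are; common choices include Euclidean, Manhattan, and Minkowski distance, and a smaller distance means the samples are more similar. The k training samples closest to the test sample are then used to make the prediction (for classification, by majority vote over their categories).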
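The three metrics named above can be written down directly. The sketch below is illustrative only: the class `Distances` and its method names are hypothetical, not part of the lab code. Note that the lab uses a single feature (height), and in one dimension all three metrics reduce to the absolute difference |a - b|, which is exactly the height difference used later.

```java
// Hypothetical helper class illustrating common KNN distance metrics.
// The lab itself only uses the 1-D absolute height difference.
public class Distances {
    // Euclidean distance: sqrt(sum of (a_i - b_i)^2)
    public static double euclidean(double[] a, double[] b) {
        checkSameLength(a, b);
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Manhattan distance: sum of |a_i - b_i|
    public static double manhattan(double[] a, double[] b) {
        checkSameLength(a, b);
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    // Minkowski distance of order p >= 1: (sum of |a_i - b_i|^p)^(1/p).
    // p = 1 gives Manhattan distance, p = 2 gives Euclidean distance.
    public static double minkowski(double[] a, double[] b, double p) {
        checkSameLength(a, b);
        if (p < 1) {
            throw new IllegalArgumentException("p must be >= 1");
        }
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.pow(Math.abs(a[i] - b[i]), p);
        }
        return Math.pow(sum, 1.0 / p);
    }

    private static void checkSameLength(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
    }

    public static void main(String[] args) {
        double[] x = {0, 0};
        double[] y = {3, 4};
        System.out.println(euclidean(x, y));    // 5.0
        System.out.println(manhattan(x, y));    // 7.0
        System.out.println(minkowski(x, y, 3)); // (27 + 64)^(1/3), about 4.498
    }
}
```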

3. Program Flowchart

[Figure: program flowchart]

II. Code

1. Implementation Approach

A Person class is defined with the following fields:

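    int id;                   // record number
    String name;              // person's name
    double height;            // height, the only feature used for the distance
    String category;          // class label (tall / short / medium)
    double heightDifference;  // |height - test height|, filled in during classification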

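File reading: the required data are read from a text file (data.txt here; the complete code below reads ../5-KNN/dataset.txt). The first line is a header, and each following line holds four whitespace-separated columns: id, name, height, and category. The file contents are:

[Figure: contents of the data file]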

    private static List<Person> readData(String fileName) throws Exception {
        List<Person> person = new ArrayList<>();
        // try-with-resources closes the reader even if parsing throws
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String line = br.readLine(); // skip the header line
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) continue; // skip blank lines
                // Split on runs of one or more whitespace characters
                String[] parts = line.split("\\s+");
                int id = Integer.parseInt(parts[0]);
                String name = parts[1];
                double height = Double.parseDouble(parts[2]);
                String category = parts[3];
                person.add(new Person(id, name, height, category, 0));
            }
        }
        return person;
    }
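For testing, it can help to separate parsing from file access. The following is a sketch, not the author's code: `DataParser`, `Row`, and `parse` are hypothetical names, `Row` is a minimal stand-in for `Person`, and the sample text in `main` is made up purely to show that runs of spaces, tabs, and blank lines are tolerated. It uses a `record`, so it assumes Java 16 or later.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

// Hypothetical variant of readData that accepts any Reader (file, string, network stream).
public class DataParser {
    // Minimal stand-in for the lab's Person class
    public record Row(int id, String name, double height, String category) {}

    public static List<Row> parse(Reader source) throws IOException {
        List<Row> rows = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            br.readLine(); // skip the header line
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) continue;        // tolerate blank lines
                String[] parts = line.split("\\s+"); // one or more spaces or tabs
                if (parts.length < 4) {
                    throw new IOException("expected 4 columns but got: " + line);
                }
                rows.add(new Row(Integer.parseInt(parts[0]), parts[1],
                        Double.parseDouble(parts[2]), parts[3]));
            }
        }
        return rows;
    }

    public static void main(String[] args) throws IOException {
        // Made-up sample input, not the real dataset
        String text = "id name height category\n1  A   1.60 short\n\n2\tB\t1.85  tall\n";
        for (Row r : parse(new StringReader(text))) {
            System.out.println(r);
        }
    }
}
```

Calling `parse(new FileReader(fileName))` would give the same behavior as `readData`, while tests can pass a `StringReader`.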

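1. Compute the height difference between the test record (unknownPerson) and each tuple in the dataset, and store its absolute value in the heightDifference field.
2. Using these differences, select the k tuples closest to the test record and store them in the nearestNeighbors list.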
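Steps 1 and 2 can also be done by sorting all tuples by their height difference and keeping the first k. The sketch below is an alternative, not the lab's implementation: it assumes a hypothetical `Sample` record (height plus category) in place of `Person` and uses made-up values in `main`. Sorting costs O(n log n) rather than the O(n·k) of the replacement loop used in the lab, which makes no practical difference on a small classroom dataset.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sort-based selection of the k nearest neighbors.
public class SortSelect {
    // Minimal stand-in for Person: a height and its category label
    public record Sample(double height, String category) {}

    // Returns the k samples whose heights are closest to queryHeight.
    public static List<Sample> nearest(List<Sample> data, double queryHeight, int k) {
        List<Sample> sorted = new ArrayList<>(data); // copy so the caller's list is untouched
        // Step 1: distance = |height - queryHeight|; Step 2: sort ascending by that distance
        sorted.sort(Comparator.comparingDouble(s -> Math.abs(s.height() - queryHeight)));
        // Keep the first k, clamped to [0, data size]
        int n = Math.max(0, Math.min(k, sorted.size()));
        return new ArrayList<>(sorted.subList(0, n));
    }

    public static void main(String[] args) {
        // Made-up sample values
        List<Sample> data = List.of(
                new Sample(1.50, "short"), new Sample(1.70, "medium"),
                new Sample(1.90, "tall"), new Sample(1.72, "medium"));
        System.out.println(nearest(data, 1.71, 2));
    }
}
```

A side benefit is that the returned neighbors come out ordered nearest first.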

        // Compute |height - unknownPerson.height| for every tuple and store it in heightDifference
        for (int i = 0; i < person.size(); i++) {
            Person pers = person.get(i);
            double difference = Math.abs(pers.height - unknownPerson.height);
            // Round to two decimal places (locale-independent, unlike DecimalFormat + parseDouble)
            pers.heightDifference = Math.round(difference * 100) / 100.0;
            System.out.println((i + 1) + " " + pers.heightDifference);
        }
        System.out.println();

        // Put the first k tuples into nearestNeighbors as the initial candidates
        List<Person> nearestNeighbors = new ArrayList<>();
        for (int i = 0; i < k && i < person.size(); i++) {
            nearestNeighbors.add(person.get(i));
        }

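Next, for each remaining element, compare its height difference with those of the current k nearest neighbors.
If it is smaller than the largest of them (that is, the current farthest neighbor), the farthest neighbor is replaced by the current element, so the list always holds the k closest tuples seen so far.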

        // For each remaining tuple, find the current farthest neighbor and replace it if this tuple is closer
        for (int i = nearestNeighbors.size(); i < person.size(); i++) {
            Person pers = person.get(i);
            double maxDifference = -1;
            int maxIndex = -1;
            for (int j = 0; j < nearestNeighbors.size(); j++) {
                if (nearestNeighbors.get(j).heightDifference > maxDifference) {
                    maxDifference = nearestNeighbors.get(j).heightDifference;
                    maxIndex = j;
                }
            }
            if (maxIndex >= 0 && pers.heightDifference < maxDifference) {
                nearestNeighbors.set(maxIndex, pers);
            }
        }
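The replacement loop rescans all k neighbors for every tuple to find the farthest one. Keeping the neighbors in a max-heap ordered by distance puts the farthest one at the head of the queue instead. The sketch below uses `java.util.PriorityQueue` with a reversed comparator; `HeapSelect` and `Sample` are hypothetical names, and this is an alternative to, not a copy of, the lab's code.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical heap-based version of "replace the farthest neighbor".
public class HeapSelect {
    // Minimal stand-in for Person: a height and its category label
    public record Sample(double height, String category) {}

    public static List<Sample> nearest(List<Sample> data, double queryHeight, int k) {
        if (k <= 0) {
            return new ArrayList<>();
        }
        Comparator<Sample> byDistance =
                Comparator.comparingDouble(s -> Math.abs(s.height() - queryHeight));
        // Max-heap: the head is always the farthest of the neighbors kept so far
        PriorityQueue<Sample> heap = new PriorityQueue<>(k, byDistance.reversed());
        for (Sample s : data) {
            if (heap.size() < k) {
                heap.add(s);                                  // fill up to k first
            } else if (byDistance.compare(s, heap.peek()) < 0) {
                heap.poll();                                  // drop the farthest neighbor
                heap.add(s);                                  // keep the closer sample
            }
        }
        List<Sample> result = new ArrayList<>(heap);
        result.sort(byDistance);                              // nearest first, for display
        return result;
    }

    public static void main(String[] args) {
        // Made-up sample values
        List<Sample> data = List.of(
                new Sample(1.50, "short"), new Sample(1.70, "medium"),
                new Sample(1.90, "tall"), new Sample(1.72, "medium"));
        System.out.println(nearest(data, 1.71, 2));
    }
}
```

Each candidate now costs O(log k) instead of O(k); like the original loop, a tuple only replaces a neighbor when it is strictly closer.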

Majority voting: once the k nearest neighbors have been found, KNN counts the categories they belong to and takes the most frequent category as the prediction.

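        // Count how many of the k neighbors fall into each category
        Map<String, Integer> categoryCounts = new HashMap<>();
        for (Person neighbor : nearestNeighbors) {
            categoryCounts.put(neighbor.category, categoryCounts.getOrDefault(neighbor.category, 0) + 1);
        }

        // Find the most frequent category (ties are broken by HashMap iteration order)
        String predictedCategory = "";
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
            String category = entry.getKey();
            int count = entry.getValue();
            if (count > maxCount) {
                predictedCategory = category;
                maxCount = count;
            }
        }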
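Because `categoryCounts` is a `HashMap`, a tie between two categories is resolved by the map's iteration order rather than by the data. A minimal sketch of one deterministic rule, assuming the neighbor labels are supplied nearest first (for example, after sorting the neighbors by heightDifference): count in a `LinkedHashMap`, which preserves first-insertion order, and keep the strict `>` comparison, so a tie goes to the category whose closest member is nearest. `MajorityVote.vote` is a hypothetical helper, not part of the lab code.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical majority vote with a deterministic tie-break.
public class MajorityVote {
    // labelsNearestFirst: the k neighbors' categories, ordered by increasing distance.
    public static String vote(List<String> labelsNearestFirst) {
        // LinkedHashMap iterates in order of first appearance, i.e. nearest member first
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String label : labelsNearestFirst) {
            counts.merge(label, 1, Integer::sum);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            // Strict '>' keeps the earlier (nearer) category when counts are equal
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best; // null when the list is empty
    }

    public static void main(String[] args) {
        System.out.println(vote(List.of("medium", "tall", "medium"))); // medium
        System.out.println(vote(List.of("tall", "medium")));           // tall (tie, nearer wins)
    }
}
```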
2. Complete Code
  • Person.java
/**
 * Created by 23222 on 2023/12/12.
 */
public class Person
{
    int id;
    String name;
    double height;
    String category;
    double heightDifference;

    public Person(int id, String name, double height, String category, double heightDifference) {
        this.id = id;
        this.name = name;
        this.height = height;
        this.category = category;
        this.heightDifference = heightDifference;
    }


    @Override
    public String toString() {
        return "Person{" +
                "id=" + id +
                ", name='" + name + '\'' +
                ", height=" + height +
                ", category='" + category + '\'' +
                ", heightDifference=" + heightDifference +
                '}';
    }
}

  • kNN.java
import java.io.BufferedReader;
import java.io.FileReader;
import java.text.DecimalFormat;
import java.util.*;

/**
 * Created by 23222 on 2023/12/12.
 */
public class kNN {
    public static void main(String[] args) throws Exception {
        // Read the labeled dataset
        List<Person> dataset = readData("../5-KNN/dataset.txt");
        Scanner scanner = new Scanner(System.in);
        int id = dataset.size();
        System.out.println("Enter the name:");
        String name = scanner.next();  // reads a single whitespace-delimited word
        System.out.println("Enter the height:");
        double height = scanner.nextDouble();
        System.out.println("Enter the K:");
        int k = scanner.nextInt();
        // K must lie between 1 and the dataset size for the neighbor search to make sense
        if (k < 1 || k > dataset.size()) {
            System.out.println("K must be between 1 and " + dataset.size());
            return;
        }
        Person unknownPerson = new Person(id, name, height, null, 0);

        String predictedCategory = classify(k, unknownPerson, dataset);
        unknownPerson.category = predictedCategory;
        System.out.println("K=" + k);
        System.out.println("Prediction: " + unknownPerson.name + "   " + unknownPerson.height + "----->" + unknownPerson.category);
    }


    private static String classify(int k, Person unknownPerson, List<Person> person) {
        // Compute the absolute height difference between the test record and each tuple
        for (int i = 0; i < person.size(); i++) {
            Person pers = person.get(i);
            double difference = Math.abs(pers.height - unknownPerson.height);
            // Round to two decimal places (locale-independent, unlike DecimalFormat + parseDouble)
            pers.heightDifference = Math.round(difference * 100) / 100.0;
            System.out.println((i + 1) + " " + pers.heightDifference);
        }
        System.out.println();

        // Seed the neighbor list with the first k tuples (never more than the dataset holds)
        List<Person> nearestNeighbors = new ArrayList<>();
        for (int i = 0; i < k && i < person.size(); i++) {
            nearestNeighbors.add(person.get(i));
        }
        // For each remaining tuple, replace the current farthest neighbor if this tuple is closer
        for (int i = nearestNeighbors.size(); i < person.size(); i++) {
            Person pers = person.get(i);
            double maxDifference = -1;
            int maxIndex = -1;
            for (int j = 0; j < nearestNeighbors.size(); j++) {
                if (nearestNeighbors.get(j).heightDifference > maxDifference) {
                    maxDifference = nearestNeighbors.get(j).heightDifference;
                    maxIndex = j;
                }
            }
            if (maxIndex >= 0 && pers.heightDifference < maxDifference) {
                nearestNeighbors.set(maxIndex, pers);
            }
        }

        for (Person neighbor : nearestNeighbors) {
            System.out.println(neighbor);
        }

        // Count how often each category (tall / short / medium) occurs among the k neighbors
        Map<String, Integer> categoryCounts = new HashMap<>();
        for (Person neighbor : nearestNeighbors) {
            String category = neighbor.category;
            categoryCounts.put(category, categoryCounts.getOrDefault(category, 0) + 1);
        }

        for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
            System.out.println(entry.getKey() + " : " + entry.getValue());
        }

        // Pick the most frequent category (ties are broken by HashMap iteration order)
        String predictedCategory = "";
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : categoryCounts.entrySet()) {
            String category = entry.getKey();
            int count = entry.getValue();
            if (count > maxCount) {
                predictedCategory = category;
                maxCount = count;
            }
        }

        return predictedCategory;
    }

    private static List<Person> readData(String fileName) throws Exception {
        List<Person> person = new ArrayList<>();
        // try-with-resources closes the reader even if parsing throws
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String line = br.readLine(); // skip the header line
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) continue; // skip blank lines
                // Split on runs of one or more whitespace characters
                String[] parts = line.split("\\s+");
                int id = Integer.parseInt(parts[0]);
                String name = parts[1];
                double height = Double.parseDouble(parts[2]);
                String category = parts[3];
                person.add(new Person(id, name, height, category, 0));
            }
        }
        return person;
    }

}

III. Experiment Results

[Figure: experiment results]
