knn算法java实现鸢尾花,原生Python实现KNN算法(使用鸢尾花数据集)

一.题目:

原生python实现knn分类算法(使用鸢尾花数据集)

K最近邻(KNN,K-nearestNeighbor)分类算法的核心思想是如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本属于也属于这个类别,并具有这个类别样本上的特性。

即选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签。所以要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离(欧式距离),选取距离最近的k个样本,获取他们的标签,然后找出k个样本中数量最多的标签,返回该标签。

欧式距离:

m维空间的距离公式,d2 = (x1 - y1)2 + (x2 - y2)2  + ... + (xm - ym)2

二.算法设计:

1.导入数据:从CSV中读取数据,并把它们分割成训练数据集和测 试数据集。

数据集获取地址:

https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

2.计算数据集中训练集数据与测试集两个数据实例之间的距离(欧氏距离)。

3.返回临近确定最相近的N个实例。

4.返回预测结果,对k个近邻进行合并,返回value最大的key 。

5.计算准确度,总结预测的准确度。

三.源代码:

import csv #用于处理csv文件

import random #用于随机数

import math

import operator

#加载数据集

def loadDataset(filename,split,trainingSet=[],testSet = []):

with open(filename,"r") as csvfile:

lines = csv.reader(csvfile)

dataset = list(lines)

for x in range(len(dataset)-1):

for y in range(4):

dataset[x][y] = float(dataset[x][y])

if random.random()预测=' ,repr(result),',实际=',repr(testSet[x][-1]))

accuracy = getAccuracy(testSet, predictions)

print('准确率:',repr(accuracy),'%')

if __name__ =="__main__":

main()

四.调试:

1.加载数据文件

(错误的:原因是读取数据集文件错误,因为open默认读取的是文本文件,此处的数据集保存形式不是文本文件,将rb改为r就可以了)

f25c59e47daafed066ffc5e440ba754c.png

(改正后:)

e750cff8e594ca82a2950e4590f7291f.png

2.k临近调试

aa01320f565b3f4aef6b9d55c1d2e901.png

五.测试:

1.

数据集切分测试代码:

dbbb062ad5c4600ca518d2f3e3afd225.png

随机切分结果:

5627b3a47ca80fe1420a4e8fc1767387.png

2.

k最临近测试代码:

22e12e45c9da44c965feb850192528f8.png

测试结果:

9cf179df62d8dbd6c5808dd933068e64.png

3.

欧氏距离测试代码:

74ab65fcaae1e43141d87c71d044bb67.png

测试结果:

c1d2d9d58ee59da8d69672d664b15410.png

六.运行结果:、

a9e325500964ca55bf77d64a06f4d1e2.png

364417f7bbe4ef0bd0cd608b0505f230.png

七.总结:

学习参考了CSDN文章:

https://blog.csdn.net/gxlmsw1314/article/details/51684104?utm_source=app

在编程过程中的第一步就出现了错误,主要原因是对open打开文件的默认格式不清楚。接下来是对k个最临近数据进行排序时浪费了很长时间,找了很多方法,最终使用operator.itemgetter(1)这个方法完成了排序。

收获:对于python的语法有了进一步的了解,而且将最基础的语法也进行了学习(因为一开始我的输出语句都能写错),也知道了几个用于对数据字典进行操作的函数。最主要的是这次将字典,列表的知识掌握的很熟悉。

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,对于knn算法Java实现,我们可以按照以下步骤进行: 1. 读取数据集 2. 对数据集进行预处理,例如数据归一化、特征选择等 3. 将数据集划分为训练集和测试集 4. 针对测试集中的每个样本,计算其与训练集中所有样本的距离 5. 对距离计算结果进行排序,选择前k个距离最近的样本 6. 根据这k个样本的类别,使用投票法确定测试样本所属的类别 7. 计算模型的准确率 相关代码实现如下: ``` // 1. 读取数据集 List<List<Double>> dataSet = new ArrayList<>(); List<Integer> labels = new ArrayList<>(); try { BufferedReader reader = new BufferedReader(new FileReader("iris.data")); // 数据集文件路径 String line; while ((line = reader.readLine()) != null) { String[] fields = line.split(","); List<Double> dataRow = new ArrayList<>(); for (int i = 0; i < fields.length - 1; i++) { dataRow.add(Double.parseDouble(fields[i])); } dataSet.add(dataRow); labels.add(Integer.parseInt(fields[fields.length - 1])); } reader.close(); } catch (IOException e) { e.printStackTrace(); } // 2. 数据预处理(省略) // 3. 将数据集划分为训练集和测试集 int n = dataSet.size(); int m = n * 7 / 10; // 训练集占比70% List<List<Double>> trainData = new ArrayList<>(); List<Integer> trainLabels = new ArrayList<>(); List<List<Double>> testData = new ArrayList<>(); List<Integer> testLabels = new ArrayList<>(); List<Integer> indexList = new ArrayList<>(); for (int i = 0; i < n; i++) { indexList.add(i); } Collections.shuffle(indexList); // 打乱数据集顺序 for (int i = 0; i < m; i++) { trainData.add(dataSet.get(indexList.get(i))); trainLabels.add(labels.get(indexList.get(i))); } for (int i = m; i < n; i++) { testData.add(dataSet.get(indexList.get(i))); testLabels.add(labels.get(indexList.get(i))); } // 4. 计算距离和排序 int k = 5; // k值 int errorCount = 0; for (int i = 0; i < testData.size(); i++) { List<Double> testRow = testData.get(i); PriorityQueue<Pair<Double, Integer>> pq = new PriorityQueue<>((a, b) -> -Double.compare(a.getKey(), b.getKey())); // 大根堆 for (int j = 0; j < trainData.size(); j++) { List<Double> trainRow = trainData.get(j); double dist = 0; for (int c = 0; c < testRow.size(); c++) { dist += Math.pow(testRow.get(c) - trainRow.get(c), 2); } pq.offer(new Pair<>(dist, trainLabels.get(j))); if (pq.size() > k) { pq.poll(); } } // 6. 投票法 int[] count = new int[3]; Arrays.fill(count, 0); for (Pair<Double, Integer> pair : pq) { count[pair.getValue() - 1]++; } int predict = 1; for (int j = 1; j < 3; j++) { if (count[j] > count[predict - 1]) { predict = j + 1; } } if (predict != testLabels.get(i)) { errorCount++; } } // 7. 计算准确率 double accuracy = (double) (testData.size() - errorCount) / testData.size(); System.out.println("Accuracy: " + accuracy); ``` 其中,我们假设数据集的类别只有3种,分别为1、2、3。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值