一、基本介绍
k-近邻算法又称kNN,全称是k-Nearest Neighbors算法,它是数据挖掘和机器学习中常用的学习算法,也是机器学习中最简单的分类算法之一。kNN算法用一句通俗的古语来说就是:“物以类聚,人以群分”,要判断一个实例的类别,就可以看它附近是什么类别。kNN的使用范围很广泛,在样本量足够大的前提条件之下它的准确度非常高。
二、核心思想
计算每个训练数据到待分类元组的距离,取和待分类元组最近的k个训练数据,k个数据中哪个类别的训练数据占多数,则待分类元组就属于哪个类别。
三、原理演示
首先计算每个训练数据到待分类元组的距离:
取k=5,选择k个和待分类元组最近的k个训练数据:
k个数据中红色样本占多数,则待分类元组分类为红色。
四、算法流程图
五、关键源码展示
1、导入数据
2、计算距离,进行分类
3、输出分类结果
六、拓展实验
人工添加二维点阵数据集,测试在二维数据下的分类效果
七、完整代码与数据集
1、完整代码
import numpy as np
import collections
def loadData():
try:
with open("kNN.txt", "r") as f:
flines = f.readlines()
Data = []
Labels = []
for item in flines[0:]:
Data.append(float(item.strip().split(',')[0]))
Labels.append(item.strip().split(',')[1])
print(Data)
print(Labels)
return Data, Labels
except Exception as e:
print(e)
# def loadDataDemo():
# try:
# with open("Demo.txt", "r") as f:
# flines = f.readlines()
# Data = []
# Labels = []
# for item in flines[0:]:
# Data.append([float(item.split(',')[0]), float(item.split(',')[1])])
# Labels.append(item.strip().split(',')[2])
# print(Data)
# print(Labels)
# return Data, Labels
# except Exception as e:
# print(e)
def classify(test, dataset, label, k):
# 计算距离
dist = [0] * len(dataset)
for i in range(len(dataset)):
dist[i] = round(abs(test - dataset[i]), 2)
# for i in range(len(dataset)):
# dist[i] = round(((test[0] - dataset[i][0]) ** 2 + (test[1] - dataset[i][1]) ** 2) ** 0.5, 2)
print('dist={}'.format(dist))
d = np.array(dist)
# k个最近的标签
k_labels = [label[index] for index in d.argsort()[0: k]]
print('k_labels(k={})={}'.format(k, k_labels))
# 出现次数最多的标签即为最终类别
label = collections.Counter(k_labels).most_common(1)[0][0]
return label
def main():
# group, labels = loadDataDemo()
# test = [2, 4]
# k = 13
# test_class = classify(test, group, labels, k)
# print('Label(k={})={}'.format(k, test_class))
#
# test = [8, 7]
# k = 13
# test_class = classify(test, group, labels, k)
# print('Label(k={})={}'.format(k, test_class))
# 创建数据集
group, labels = loadData()
# kNN分类
test = 1.64
print('test={}'.format(test))
for k in range(3, 6):
test_class = classify(test, group, labels, k)
print('Label(k={})={}'.format(k, test_class))
print()
test = 1.74
print('test={}'.format(test))
for k in range(3, 6):
test_class = classify(test, group, labels, k)
print('Label(k={})={}'.format(k, test_class))
print()
test = 1.84
print('test={}'.format(test))
for k in range(3, 6):
test_class = classify(test, group, labels, k)
print('Label(k={})={}'.format(k, test_class))
if __name__ == "__main__":
main()
2、数据集
(1)原始数据集
1.5,low
1.92,high
1.7,medium
1.73,medium
1.6,low
1.75,medium
1.6,low
1.9,high
1.68,medium
1.78,medium
1.70,medium
1.68,medium
1.65,low
1.78,medium
(2)拓展数据集
1,4,red
1,3,red
2,3,red
2,2,red
3,2,red
4,2,red
5,2,red
6,2,red
7,2,red
8,2,red
8,3,red
9,3,red
9,4,red
1,7,blue
1,8,blue
2,8,blue
2,9,blue
3,9,blue
4,9,blue
5,9,blue
6,9,blue
7,9,blue
8,9,blue
8,8,blue
9,8,blue
9,7,blue