什么是K-邻近算法?
K-邻近算法(k-NearestNeighbor)简称KNN,是分类算法中的一种。KNN通过计算新数据与历史样本数据中不同类别数据点间的距离对新数据进行分类。简单来说就是通过与新数据点最邻近的K个数据点来对新数据进行分类和预测。K-邻近分类算法是数据挖掘(classification)技术中最简单的算法之一,其指导思想是”近朱者赤,近墨者黑“,即由你的邻居来推断出你的类别。
k-邻近算法工作原理
存在一个样本数据集合,也称作训练样本集,并且样本集合每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最邻近)的分类标签。一般来说我们只选择样本数据集中前K个最相似的数据,这就是K-邻近算法中K的出处,通常K是不大于20的整数。最后,选择K个最相似数据中出现最多的分类,作为新数据的分类。K邻近模型由三个基本要素–距离度量、K值选择和分类决策规则决定。
K-邻近算法的一个经典例子
基于电影中的搞笑、拥抱、打斗镜头,使用 k-近邻算法构造程序,就可以自动划分电影的题材类型。(图来自机器学习之KNN(k近邻)算法详解)
d为《唐人街探案》与列表中各个电影的欧氏距离
如图可知与《唐人街探案》距离最近的五个电影中有4个喜剧片,一个爱情片,由此可推出《唐人街探案》是一部喜剧片。
K-邻近算法距离计算的两种方法
K-邻近算法的一般流程
(1)收集数据:可以使用任何方法。
(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式。
(3)分析数据:可以使用任何方法
(4)训练算法:此算法不适用于K-邻近算法
(5)测试算法:计算错误率
(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行K-邻近算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。
K-邻近算法实例:数字手写体的识别
import numpy as np
from image import image2onebit as it
import sys
from tensorflow.examples.tutorials.mnist import input_data
import math
import datetime
#KNN算法主体:计算测试样本与每一个训练样本的距离
def get_index(train_data,test_data, i):
#1、 np.argmin(np.sqrt(np.sum(np.square(test_data[i]-train_data),axis=1)))
#2、a数组存入:测试样本与每一个训练样本的距离
all_dist = np.sqrt(np.sum(np.square(test_data[i]-train_data),axis=1)).tolist()
return all_dist
#KNN算法主体:计算查找最近的K个训练集所对应的预测值
def get_number(all_dist):
all_number = []
min_index = 0
#print('距离列表:', all_dist,)
for k in range(Nearest_Neighbor_number):
# 最小索引值 = 最小距离的下标编号
min_index = np.argmin(all_dist)
#依据最小索引值(最小距离的下标编号),映射查找到预测值
ss = np.argmax((train_label[min_index])).tolist()
print('第',k+1,'次预测值:',ss)
#将预测值改为字符串形式存入新元组bb中
all_number = all_number + list(str(ss))
#在距离数组中,将最小的距离值删去
min_number = min(all_dist)
xx = all_dist.index(min_number)
del all_dist[xx]
print('预测值总体结果:',all_number)
return all_number
#KNN算法主体:在K个预测值中,求众数,找到分属最多的那一类,输出
def get_min_number(all_number):
c = []
#将string转化为int,传入新列表c
for i in range(len(all_number)):
c.append(int(all_number[i]))
#求众数
new_number = np.array(c)
counts = np.bincount(new_number)
return np.argmax(counts)
t1 = datetime.datetime.now() #计时开始
print('说明:训练集数目取值范围在[0,60000],K取值最好<10\n' )
train_sum = int(input('输入训练集数目:'))
Nearest_Neighbor_number = int(input('选取最邻近的K个值,K='))
#依照文件名查找,读取训练与测试用的图片数据集
mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
#取出训练集数据、训练集标签
train_data, train_label = mnist.train.next_batch(train_sum)
#调用自创模块内函数read_image():依照路径传入图片处理,将图片信息转换成numpy.array类型
x1_tmp = it.read_image("png/55.png")
test_data = it.imageToArray(x1_tmp)
test_data = np.array(test_data)
#print('test_data',test_data)
#调用自创模块内函数show_ndarray():用字符矩阵打印图片
it.show_ndarray(test_data)
#KNN算法主体
all_dist = get_index(train_data,test_data,0)
all_number = get_number(all_dist)
min_number = get_min_number(all_number )
print('最后的预测值为:',min_number)
t2=datetime.datetime.now()
print('耗 时 = ',t2-t1)
评价:使用的训练集、测试集数据来源于Google的那个经典的MNIST手写数字数据集。程序限制图片数据大小是28*28的,也就是说像素点一共784个,所以缺陷在于(应该说是KNN算法缺陷硬伤)
大多数数据图片占据的像素点很接近,距离区分度比较低;
未考虑不同数字间的内部结构特征
总结K-邻近算法的优缺点
优点:简单,易于理解,无需建模与训练,易于实现;适合对稀有事件进行分类;适合与多分类问题,例如根据基因特征来判断其功能分类。
缺点:惰性算法,内存开销大,对测试样本分类时计算量大,性能较低;可解释性差,无法给出决策树那样的规则。