用python手写KNN算法+kd树及其BBF优化(原理与实现)(上篇)
初学python和机器学习,突然兴起想动手用python实践一下KNN算法,本来想着这个算法原理很简单明了,应该实现起来没什么大问题,然而真正上手的时候问题频出,花了好一些功夫挨个排除各种奇怪的bug,总算是大功告成。接下来我会介绍一下算法的手写实现和在此过程中亲遇的各种问题,希望能够帮到大家。实验所需数据链接在文章最后。
ps:从学习C语言以来形成了print调试的毛病,所以在代码中保留了一些用于调试的输出重要信息的print语句,放在了后面的完整print信息版代码中,方便理解代码和调试。千万注意在使用time()方法测试程序运行时间时要把这些调试的print语句注释掉。
pps:关于numpy的疑惑建议随时参考https://www.runoob.com/numpy/numpy-dtype.html,很方便
1. KNN算法与kd树简介
1.1 什么是KNN算法?
网上关于KNN的详细介绍很多,简单来说,KNN是一种有监督分类算法,通过计算待分类数据点,与已有数据集中的所有数据点的距离。取距离最小的前K个点,根据“少数服从多数“的原则,将这个数据点划分为出现次数最多的那个类别。如图由KNN得到Xu属于ω1
因此,将分类点输入的过程就是KNN算法的学习过程,将已分类点全部输入后,要完成对未分类点所属类别的预测,重点是找出距离未分类点最近的前K个已分类点
1.2 为什么需要kd树?
前面我们说道,要完成对未分类点所属类别的预测,重点是找出距离未分类点最近的前K个已分类点。那么,对于每个未分类点,一般我们需要求出它与所有以分类点的距离,然后找出前k个距离最小的已分类点。如果已分类点集合中有n个点,那么如果我们要对m个未分类点进行预测,时间复杂度为O(m*n)。当n很大时,我们认为这样不是很高效。
那么,有没有一种方法让上述复杂度变为O(mlogn)呢?这时我们想到了二叉树。类比二叉查找树(BST),Kd-Tree即K-dimensional tree,是一棵二叉树,树中存储的是一些K维数据。在一个K维数据集合上构建一棵Kd-Tree代表了对该K维数据集合构成的K维空间的一个划分。即树中的每一个结点就相应了一个K维的超矩形区域(Hyperrectangle),kd树的详细介绍以及如何构造kd树将在下面介绍。
2. 数据集准备
首先,准备数据集:这里的数据集即指KNN算法的训练集和测试集。对于KNN算法来说,将训练集输入的过程就是KNN算法学习的过程。训练集和测试集由多个样本构成,每个样本由其特征向量和标签构成,也就是由特征和类别构成。举个例子,某样本的特征向量为(唱,跳,rap,篮球),标签为蔡徐坤,将它作为训练集输入后,测试集中我们给出(唱,跳,rap,鸡你太美),由KNN算法我们预测出该测试样本对应标签为蔡徐坤,和测试标签对比发现本次预测成功。为了方便编程,我们将训练集和测试集处理均处理为由(特征1,特征2,… ,特征n,标签)这样的向量组成的集合,称之为数据矩阵。如:
训练集(或测试集):
唱,跳,rap,篮球,蔡徐坤
拐,黑土,不差钱,小品,赵本山
…
为了简单这次实验用的是DBRHD数据集。
2.1 DBRHD数据集
DBRHD(Pen-Based Recognition of Handwritten Digits Data Set)是UCI的机器学习中心提供的数字手写体数据库可以在https://archive.ics.uci.edu/ml/datasets/PenBased+Recognition+of+Handwritten+Digits下载,不过我相信从这里得到的数据集会让你一头雾水,所以我会把我用到的数据文本放到文章后面的链接中。
DBRHD数据集包含大量的数字0~9的手写体图片,这些图片来源于44位不同的人的手 写数字,图片已归一化为以手写数字为中心的32*32规格的图片。DBRHD的训练集与测试集组成如下:
(1)训练集:7,494个手写体图片及对应标签,来源于40位手写者
(2)测试集:3,498个手写体图片及对应标签,来源于14位手写者
我们把训练集和测试集转化为前面介绍的向量集合的格式存放到文本中,分为两个版本:
(1)特征个数为16的版本:
训练集training1.txt:
其中每一行代表一个(特征1,特征2,…,特征16,标签)的向量,
如47,100,27,81,57,37,26,0,0,23,56,53,100,90,40,98这16个特征决定了它代表数字8。下面的测试集也是类似。
测试集test1.txt:
(2)特征个数为1024的版本(这一版训练集样本有1934个,测试集样本有946个):
训练集training2.txt和测试集test2.txt
(太占版面了,只贴一个向量吧,前面1024个0或1组成的特征代表数字8)
2.2 编写数据读取函数
先导入这次实验需要的全部模块
import numpy as np
import queue # 后续bbf会用
import time
读取文件函数:
def loadData(filePath): # 读文件
with open(filePath, 'r+') as fr:
# with语句会自动调用close()方法,且比显式调用更安全
lines = fr.readlines()
data = []
for line in lines: # 逐行读入
items = line.strip().split(",")
data.append([int(items[i]) for i in range(len(items))])
return np.asarray(data) # 以np.ndarray类型数组返回
3. 构建kd树(kd-tree)
得到数据后我们就可以构建kd树了,KNN算法其实本身并没有真正意义的学习的过程,构建kd树的过程就作为它的“学习”过程。
首先,我们要知道什么是kd树:
我们先回想一下二叉查找树(或二叉排序树)即BST:
二叉查找树(Binary Search Tree,BST)。是具有例如以下性质的二叉树:
1)若它的左子树不为空。则左子树上全部结点的值均小于它的根结点的值;
2)若它的右子树不为空,则右子树上全部结点的值均大于它的根结点的值;
3)它的左、右子树也分别为二叉排序树;
如图是一棵BST: