上一篇已经简单粗暴的建立了一个KNN模型对手写图片进行了识别,所以本篇文章采用构造KD树的方法实现手写数字的识别。
上一篇链接:https://blog.csdn.net/qq_33361618/article/details/82887121
(一)构造KD树
构造KD树的基本原理网上都有介绍,所以废话不多说,直接上代码。
#Knn KD_Tree算法
import math
from collections import namedtuple
#定义命名元祖,用来存放结果,最近点,最近距离和访问过的节点数
result = namedtuple('Result_tuple', 'nearest_point nearest_dist nodes_visited')
# In[5]:
#构造KD树
#初始化构造KD树的元素
class KD_Node(object):
def __init__(self, dom_elt, split, left, right):
self.dom_elt = dom_elt #k维向量节点
self.split = split #整数,进行分割的序号
self.left = left #该节点分割超平面的左子树
self.right = right #该节点分割超平面的右子树
class KD_Tree(object):
def __init__(self, data):
k = len(data[0]) #数据的维度
def Create_Node(split, data_set): #按第split维划分数据data_set创建的KD_Node
if (data_set == []): #数据集为空
return None
#key参数的值为一个函数,此函数只有一个参数且返回一个值来进行比较
#operator模块提供的itemgetter函数用来获取对象有哪些维的数据,
#参数为需要获取的数据对象中的序号
data_set = list(data_set)
data_set.sort(key=lambda x: x[split])
split_positon = len(data_set) // 2 #//代表整除
median = data_set[split_positon] #中位数
split_next = (split + 1) % k
#递归创建KD数
return KD_Node(median, split,
Create_Node(split_next, data_set[:split_positon]),
Create_Node(split_next, data_set[split_positon + 1:]))
self.root = Create_Node(0, data)
#KD树的前序遍历
def Pre_Order(root):
# print(root.dom_elt)
if (root.left):
Pre_Order(root.left)
if (root.right):
Pre_Order(root.right)
KD树构造完成后,可以计算最近邻。
#搜索最近邻
def Find_Nearest(tree, point):
k = len(point) #数据维度
def Travel(kd_node, target, max_dist):
if kd_node is None:
return result([0] * k, float("inf"), 0)#inf表示正无穷,-inf表示负无穷
nodes_visited = 1
s = kd_node.split #进行分割的维度
pivot = kd_node.dom_elt #进行分割的轴
if target[s] <= pivot[s]: #如果目标点第s维小于分割轴对应值,即目标离左子树更近
nearer_node = kd_node.left #下一个访问的点为左子树
further_node = kd_node.right #同时记录右子树
else: #目标离右子树较近的时候
nearer_node = kd_node.right #下一个访问点为右子树根节点
further_node = kd_node.left #记录左子树
temp1 = Travel(nearer_node, target, max_dist) #遍历找到包含目标点的位置
nearest = temp1.nearest_point #以此节点作为“当前最近点”
dist = temp1.nearest_dist #更新最近距离
nodes_visited += temp1.nodes_visited
if dist < max_dist:
max_dist = dist #最近点将在以目标点为圆心,max_dist为半径的超球体内
temp_dist = abs(pivot[s] - target[s]) #第s维上目标点与分割超平面的距离
if max_dist < temp_dist: #判断超球体是否与分割平面相交
return result(nearest, dist, nodes_visited)
#计算目标点与分割点的欧氏距离
temp_dist = math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))
if temp_dist < dist: #如果得到更近的点
nearest = pivot #更新更近的点
dist = temp_dist #更新最近的距离
max_dist = dist #更新超球体半径
#检查另一个子节点对应的区域是否有更近的点
temp2 = Travel(further_node, target, max_dist)
nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: #如果另一个子节点中存在更近的距离
nearest = temp2.nearest_point #更新最近的点
dist = temp2.nearest_dist #更新最近距离
return result(nearest, dist, nodes_visited)
return Travel(tree.root, point, float("inf")) #从根节点开始递归
测试结果,计算[2,4.5]离数据集:[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]中最近的点。
if __name__ == "__main__":
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KD_Tree(data)
rst = Find_Nearest(kd, [2,4.5])
[2,4.5]最近邻为[2,3],最短距离为1.5.测试结果看出KD树的效果还是不错的。那么在大数据高维度情况下,KD树的测试结果怎样呢。
(二)对比蛮力实现和KD树实现的区别
对之前处理的1万条样本数据选择8000条作为训练集,2000条作为检验集。
考虑到代码行较多的情况,本次对比使用封装模块,然后调用模块运行测试结果。
生成3个.py文件:Sample.py、Knn.py和KD_Tree.py
此部分代码与前面的代码区别不大,就不再进行复制。如有需要可以在网页链接中下载,提取码: po7s。
执行文件为Main,py
import sys
sys.path.append(r"D:/Python_work/机器学习/KNN分类算法/Knn")
from Sample import Sample_PC
from datetime import datetime
#调用参数
k = 3
train_file_route = r"E:/data/digit_data_copy/train/"
test_file_route = r"E:/data/digit_data_copy/test/"
model = "KD_Tree"
#执行蛮力实现
func1 = Sample_PC(3,train_file_route, test_file_route,None)
t1 = datetime.now()
result1 = func1.test_data()
t2 = datetime.now()
print('knn耗时:', t2-t1)
#执行KD树实现
func2 = Sample_PC(3,train_file_route, test_file_route, model=model)
t3 = datetime.now()
result2 = func2.test_data()
t4 = datetime.now()
print('KD_Tree耗时:', t4-t3)
结论:
蛮力实现:准确率:0.977,耗时:2分56秒
混淆矩阵
file_name | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|
forecast_data | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
real_data | ||||||||||
0 | 209 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
1 | 0 | 221 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 2 | 0 | 163 | 0 | 1 | 0 | 0 | 2 | 0 | 0 |
3 | 0 | 0 | 0 | 206 | 0 | 1 | 0 | 2 | 1 | 0 |
4 | 0 | 1 | 1 | 0 | 209 | 1 | 1 | 0 | 0 | 2 |
5 | 0 | 1 | 0 | 1 | 0 | 172 | 3 | 0 | 0 | 1 |
6 | 0 | 1 | 0 | 0 | 0 | 0 | 184 | 0 | 0 | 0 |
7 | 0 | 4 | 0 | 0 | 0 | 0 | 0 | 203 | 0 | 0 |
8 | 1 | 1 | 1 | 1 | 0 | 2 | 0 | 1 | 198 | 0 |
9 | 1 | 2 | 0 | 2 | 1 | 0 | 0 | 4 | 0 | 189 |
KD树实现:准确率:0.989,耗时:1个小时53分钟
混淆矩阵:
虽然,KD树的准确率在蛮力实现之上,但KD树对于高维大数据的计算大过于耗费时间,且准确率提升也不是特别高。总体而言,knn分类效果较好,但计算比较耗时,这也是它最大的一个缺点。