基于KNN分类算法手写数字识别的实现(二)——构建KD树

上一篇已经简单粗暴的建立了一个KNN模型对手写图片进行了识别,所以本篇文章采用构造KD树的方法实现手写数字的识别。

上一篇链接:https://blog.csdn.net/qq_33361618/article/details/82887121

(一)构造KD树

构造KD树的基本原理网上都有介绍,所以废话不多说,直接上代码。

#Knn KD_Tree算法

import math
from collections import namedtuple

#定义命名元祖,用来存放结果,最近点,最近距离和访问过的节点数
result = namedtuple('Result_tuple', 'nearest_point nearest_dist nodes_visited')


# In[5]:


#构造KD树

#初始化构造KD树的元素
class KD_Node(object):
    
    def __init__(self, dom_elt, split, left, right):
        
        self.dom_elt = dom_elt #k维向量节点
        self.split = split     #整数,进行分割的序号
        self.left = left       #该节点分割超平面的左子树
        self.right = right     #该节点分割超平面的右子树
        
class KD_Tree(object):
    
    def __init__(self, data):
        
        k = len(data[0])       #数据的维度
        
        def Create_Node(split, data_set): #按第split维划分数据data_set创建的KD_Node
            
            if (data_set == []):       #数据集为空

                return None
            #key参数的值为一个函数,此函数只有一个参数且返回一个值来进行比较
            #operator模块提供的itemgetter函数用来获取对象有哪些维的数据,
            #参数为需要获取的数据对象中的序号
            data_set = list(data_set)
            data_set.sort(key=lambda x: x[split])
            split_positon = len(data_set) // 2 #//代表整除
            median = data_set[split_positon] #中位数
            split_next = (split + 1) % k 
            #递归创建KD数
            return KD_Node(median, split,
                          Create_Node(split_next, data_set[:split_positon]),
                          Create_Node(split_next, data_set[split_positon + 1:]))
        
        self.root = Create_Node(0, data)
        
#KD树的前序遍历
def Pre_Order(root):
    
#     print(root.dom_elt)
    if (root.left):
        Pre_Order(root.left)
    if (root.right):
        Pre_Order(root.right)

KD树构造完成后,可以计算最近邻。

#搜索最近邻

def Find_Nearest(tree, point):
    
    k = len(point) #数据维度
    
    def Travel(kd_node, target, max_dist):
        
        if kd_node is None:
            
            return result([0] * k, float("inf"), 0)#inf表示正无穷,-inf表示负无穷
        
        nodes_visited = 1
        s = kd_node.split  #进行分割的维度
        pivot = kd_node.dom_elt #进行分割的轴
        
        if target[s] <= pivot[s]: #如果目标点第s维小于分割轴对应值,即目标离左子树更近
            
            nearer_node = kd_node.left #下一个访问的点为左子树
            further_node = kd_node.right #同时记录右子树
        else:                     #目标离右子树较近的时候
            
            nearer_node = kd_node.right #下一个访问点为右子树根节点
            further_node = kd_node.left #记录左子树
        
        temp1 = Travel(nearer_node, target, max_dist) #遍历找到包含目标点的位置
        nearest = temp1.nearest_point #以此节点作为“当前最近点”
        dist = temp1.nearest_dist     #更新最近距离
        nodes_visited += temp1.nodes_visited
        
        if dist < max_dist:
            
            max_dist = dist #最近点将在以目标点为圆心,max_dist为半径的超球体内
        
        temp_dist = abs(pivot[s] - target[s]) #第s维上目标点与分割超平面的距离
        
        if max_dist < temp_dist: #判断超球体是否与分割平面相交
            
            return result(nearest, dist, nodes_visited)
            
        #计算目标点与分割点的欧氏距离
        temp_dist = math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))
        
        if temp_dist < dist: #如果得到更近的点
            
            nearest = pivot  #更新更近的点
            dist = temp_dist #更新最近的距离
            max_dist = dist  #更新超球体半径
        
        #检查另一个子节点对应的区域是否有更近的点
        temp2 = Travel(further_node, target, max_dist)
        nodes_visited += temp2.nodes_visited
        
        if temp2.nearest_dist <  dist: #如果另一个子节点中存在更近的距离
            
            nearest = temp2.nearest_point #更新最近的点
            dist = temp2.nearest_dist     #更新最近距离
        
        return result(nearest, dist, nodes_visited)
    
    return Travel(tree.root, point, float("inf")) #从根节点开始递归

测试结果,计算[2,4.5]离数据集:[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]中最近的点。

if __name__ == "__main__":
    
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd = KD_Tree(data)
    rst = Find_Nearest(kd, [2,4.5])

[2,4.5]最近邻为[2,3],最短距离为1.5.测试结果看出KD树的效果还是不错的。那么在大数据高维度情况下,KD树的测试结果怎样呢。

(二)对比蛮力实现和KD树实现的区别

对之前处理的1万条样本数据选择8000条作为训练集,2000条作为检验集。

考虑到代码行较多的情况,本次对比使用封装模块,然后调用模块运行测试结果。

生成3个.py文件:Sample.py、Knn.py和KD_Tree.py

此部分代码与前面的代码区别不大,就不再进行复制。如有需要可以在网页链接中下载,提取码: po7s。

执行文件为Main,py

import sys
sys.path.append(r"D:/Python_work/机器学习/KNN分类算法/Knn")

from Sample import Sample_PC
from datetime import datetime


#调用参数
k = 3
train_file_route = r"E:/data/digit_data_copy/train/"
test_file_route = r"E:/data/digit_data_copy/test/"
model = "KD_Tree"


#执行蛮力实现
func1 = Sample_PC(3,train_file_route, test_file_route,None)
t1 = datetime.now()
result1 = func1.test_data()
t2 = datetime.now()
print('knn耗时:', t2-t1)


#执行KD树实现
func2 = Sample_PC(3,train_file_route, test_file_route, model=model)
t3 = datetime.now()
result2 = func2.test_data()
t4 = datetime.now()
print('KD_Tree耗时:', t4-t3)

结论:

蛮力实现:准确率:0.977,耗时:2分56秒

混淆矩阵

 file_name
forecast_data0123456789
real_data          
0209100001100
1022100000000
2201630100200
3000206010210
4011020911002
5010101723001
6010000184000
7040000020300
8111102011980
9120210040189

KD树实现:准确率:0.989,耗时:1个小时53分钟

混淆矩阵:

虽然,KD树的准确率在蛮力实现之上,但KD树对于高维大数据的计算大过于耗费时间,且准确率提升也不是特别高。总体而言,knn分类效果较好,但计算比较耗时,这也是它最大的一个缺点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值