Building and Searching a K-D Tree for KNN, with a Python Implementation

KNN

  • Input: a feature vector
  • Output: a class label
  • Find the k training points closest to the query point, then classify by a rule such as majority voting
How to define the distance between two vectors
  • Lp distance
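The Lp distance between two n-dimensional feature vectors (as defined in 《统计学习方法》) can be written as:

```latex
L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{\frac{1}{p}}
```

Here p = 2 gives the Euclidean distance and p = 1 the Manhattan distance.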
import numpy as np
x_1 = np.array([1, 2, 3, 4])
x_2 = np.array([5, 6, 7, 8])
def Lp_distance(p, x_1, x_2):
    x = np.abs(x_1 - x_2)            # element-wise |x_1 - x_2|
    Sum = np.sum(np.power(x, p))     # sum of p-th powers
    return np.power(Sum, 1 / p)      # p-th root
Lp_distance(1, x_1, x_2)
16.0
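The distance function above can be plugged directly into a k-NN classifier with majority voting. A minimal sketch (the helper `knn_classify` and the toy data are my own, not from the original post):

```python
import numpy as np

def Lp_distance(p, x_1, x_2):
    # Lp distance: p-th root of the sum of p-th powers of |x_1 - x_2|
    return np.power(np.sum(np.power(np.abs(x_1 - x_2), p)), 1 / p)

def knn_classify(train, labels, x, k=3, p=2):
    # distances from the query point x to every training sample
    dists = [Lp_distance(p, np.array(t), np.array(x)) for t in train]
    nearest = np.argsort(dists)[:k]          # indices of the k closest samples
    votes = [labels[i] for i in nearest]
    return max(set(votes), key=votes.count)  # majority vote

train = [[1, 1], [1, 2], [6, 6], [7, 7]]
labels = ['A', 'A', 'B', 'B']
print(knn_classify(train, labels, [2, 2]))   # → 'A'
```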
The value of K
  • A small k means a small neighborhood: only samples very close to the query point are treated as the same class
  • This makes the model more sensitive to noise: if a nearby sample is noise, the prediction is wrong
  • The smaller k is, the more complex the model becomes (and the easier it is to overfit)
  • In practice, k is usually chosen by cross-validation
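A minimal sketch of choosing k by leave-one-out cross-validation (the helpers `knn_predict` and `loo_error` and the toy data are assumptions, not from the original post):

```python
import numpy as np

def knn_predict(train_X, train_y, x, k):
    # Euclidean (L2) distances from x to every training point
    d = np.linalg.norm(train_X - x, axis=1)
    nearest = np.argsort(d)[:k]               # indices of the k closest points
    labels, counts = np.unique(train_y[nearest], return_counts=True)
    return labels[np.argmax(counts)]          # majority vote

def loo_error(train_X, train_y, k):
    # leave-one-out error: predict each point from all the others
    errors = 0
    for i in range(len(train_X)):
        mask = np.arange(len(train_X)) != i
        pred = knn_predict(train_X[mask], train_y[mask], train_X[i], k)
        errors += pred != train_y[i]
    return errors / len(train_X)

X = np.array([[1, 1], [1, 2], [2, 1], [6, 6], [6, 7], [7, 6]])
y = np.array([0, 0, 0, 1, 1, 1])
best_k = min([1, 3, 5], key=lambda k: loo_error(X, y, k))
```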
Build a balanced K-D tree
  • at depth j, split on feature l = (j mod k) + 1, where k is the number of features (i.e. the splitting axis cycles through the dimensions)
class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild
class KD_Tree:
    def __init__(self):
        self.kd_tree = None
        
    def create(self, dataset, depth):
        if len(dataset) > 0:
            m, n = np.shape(dataset) # m: number of samples; n: number of features
            midIndex = int(m / 2)    # index of the median after sorting
            axis = depth % n         # splitting axis cycles through the features
            sorted_dataset = self.sort(dataset, axis)
            
            node = Node(sorted_dataset[midIndex])
            left_dataset = sorted_dataset[:midIndex]
            right_dataset = sorted_dataset[midIndex + 1:]
            
            node.lchild = self.create(left_dataset, depth + 1)
            node.rchild = self.create(right_dataset, depth + 1)
            
            return node
        else:
            return None
    
    def sort(self, dataset, axis):
        # selection sort of the samples by their value on the given axis
        sorted_dataset = dataset[:]
        m, n = np.shape(sorted_dataset)
        for i in range(m-1):
            temp = i
            for j in range(i, m):
                if sorted_dataset[temp][axis] > sorted_dataset[j][axis]:
                    temp = j

            t = sorted_dataset[temp]
            sorted_dataset[temp] = sorted_dataset[i]
            sorted_dataset[i] = t
        return sorted_dataset
    
    def preOrder(self, node):
        if node != None:
            print("-> %s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)
    
    def search(self, tree, x):
        self.nearest_point = None
        self.nearest_value = None
        def travel(node, depth = 0):
            if node is not None:
                n = len(x)
                axis = depth % n
                if x[axis] < node.data[axis]:
                    travel(node.lchild, depth + 1)
                else:
                    travel(node.rchild, depth + 1)
                # after recursing down to a leaf, walk back up and update the current best
                distance_Node_x = self.dist(x, node.data)

                if self.nearest_point is None or self.nearest_value > distance_Node_x:
                    self.nearest_point = node.data
                    self.nearest_value = distance_Node_x
                    
                if abs(x[axis] - node.data[axis]) <= self.nearest_value: # the hypersphere around x crosses the splitting plane, so the other side may hold a closer point
                    if x[axis] < node.data[axis]:
                        # x lies on the left side of the plane (in node's left subtree), so also check the right subtree
                        travel(node.rchild, depth + 1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearest_point, self.nearest_value
    
    def dist(self, x_1, x_2):
        # Euclidean (L2) distance, i.e. Lp distance with p = 2
        x_1 = np.array(x_1)
        x_2 = np.array(x_2)
        x = x_1 - x_2
        x = np.abs(x)
        x = np.power(x, 2)
        Sum = np.sum(x)
        x = np.power(Sum, 1/2)
        return x
dataset = [[2, 3],[5, 4],[9, 6],[4, 7], [8, 1], [7, 2]]
x = [5, 3]
kd_tree = KD_Tree()
head = kd_tree.create(dataset, 0)
# kd_tree.preOrder(head)
nearest_point, nearest_value = kd_tree.search(head, x)
print(nearest_point)
print(nearest_value)
[5, 4]
1.0
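To sanity-check the KD-tree result, a brute-force linear scan over the same dataset should find the same nearest point (the helper `brute_force_nearest` is my own sketch, not from the original post):

```python
import numpy as np

def brute_force_nearest(dataset, x):
    # linear scan: compare x against every point and keep the closest
    best_point, best_dist = None, float('inf')
    for point in dataset:
        d = np.linalg.norm(np.array(point) - np.array(x))
        if d < best_dist:
            best_point, best_dist = point, d
    return best_point, best_dist

dataset = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
point, dist = brute_force_nearest(dataset, [5, 3])
print(point, dist)   # → [5, 4] 1.0
```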

This post is a record of reading an expert's Python implementation of the KD-tree from 《统计学习方法》 (Statistical Learning Methods).
