KDTree

 

import math
import numpy as np


class KD_node:
    """One node of a k-d tree.

    Attributes (as set by ``createKDTree``):
        point: the data point stored at this node.
        split: index of the coordinate axis this node partitions on.
        left: child subtree holding points whose ``split`` coordinate is <=
            this node's (or None).
        right: child subtree holding the remaining points (or None).
    """

    def __init__(self, point=None, split=None, LL=None, RR=None):
        # LL / RR are the initial left / right children; createKDTree passes
        # none and attaches the children afterwards.
        self.point, self.split = point, split
        self.left, self.right = LL, RR


def createKDTree(root, data_list):
    """Build a k-d tree over ``data_list`` and return its root node.

    At every level the split axis is the coordinate with the largest variance
    among the current points. The points are ordered on that axis; the one at
    index ``len // 2`` becomes the node, the points before it form the left
    subtree and the points after it form the right subtree.

    Args:
        root: ignored. Kept only so existing callers keep working; a fresh
            node is always created and returned.
        data_list: sequence of points, each an indexable sequence of numbers,
            all of the same length. It is NOT modified.

    Returns:
        The root ``KD_node`` of the new tree, or ``None`` when ``data_list``
        is empty.
    """
    n = len(data_list)
    if n == 0:
        return None
    dimension = len(data_list[0])

    # Choose the axis with the largest spread. The strict ">" means ties and
    # all-zero variances keep the lowest axis index.
    max_var = 0
    split = 0
    for axis in range(dimension):
        var = computeVariance([pt[axis] for pt in data_list])
        if var > max_var:
            max_var = var
            split = axis

    # sorted() works on a copy, so the caller's list keeps its original order.
    # The sort is stable, so the resulting tree is the same as an in-place sort.
    ordered = sorted(data_list, key=lambda pt: pt[split])
    mid = n // 2
    node = KD_node(ordered[mid], split)
    node.left = createKDTree(None, ordered[:mid])
    node.right = createKDTree(None, ordered[mid + 1:])
    return node
  
  
def computeVariance(arrayList):
    """Return the population variance (divisor N) of a sequence of numbers.

    Uses numpy's two-pass ``np.var`` (mean of squared deviations from the
    mean) rather than the one-pass ``E[x^2] - E[x]^2`` formula. The one-pass
    formula has two problems:

    - It suffers catastrophic cancellation when the values are large relative
      to their spread, and can even return a negative "variance".
    - Squaring an int64 array silently overflows for |x| > ~3e9.

    Either problem can distort a comparison of variances across axes.

    Args:
        arrayList: sequence of numbers (e.g. one coordinate of every point).

    Returns:
        The variance as a numpy floating-point scalar. An empty sequence
        yields ``nan`` and numpy emits a RuntimeWarning.
    """
    return np.var(np.asarray(arrayList))


def findNN(root, query):
    """Return the nearest neighbour of ``query`` in the k-d tree at ``root``.

    The search runs in two phases:

    1. Descend from the root towards the leaf whose region contains
       ``query``, remembering every visited node.
    2. Backtrack through the remembered nodes. A node's far subtree can only
       contain a closer point if its splitting plane is nearer than the best
       distance found so far. In that case the far subtree is searched the
       same way, with a full descent rather than just its root, so the
       query-side branches inside it are also visited.

    Args:
        root: root ``KD_node`` of the tree (as built by ``createKDTree``), or
            ``None`` for an empty tree.
        query: point with the same number of coordinates as the tree's points.

    Returns:
        ``(nearest_point, distance)``, with the distance measured by
        ``computeDist``. For an empty tree, returns ``(None, math.inf)``.
        Among equidistant points, the first one encountered is kept.
    """
    if root is None:
        return None, math.inf

    best = root.point
    best_dist = computeDist(query, best)
    stack = []

    def descend(node):
        # Walk from `node` down to a leaf, always taking the child on the
        # query's side of the split, and push each visited node so it can be
        # backtracked through later.
        nonlocal best, best_dist
        while node is not None:
            stack.append(node)
            dist = computeDist(query, node.point)
            if dist < best_dist:
                best, best_dist = node.point, dist
            axis = node.split
            if query[axis] <= node.point[axis]:
                node = node.left
            else:
                node = node.right

    descend(root)
    while stack:
        node = stack.pop()
        axis = node.split
        # Cross the splitting plane only if it is closer than the current
        # best; otherwise nothing on the far side can win.
        if abs(query[axis] - node.point[axis]) < best_dist:
            if query[axis] <= node.point[axis]:
                descend(node.right)
            else:
                descend(node.left)
    return best, best_dist
  
  
def computeDist(pt1, pt2):
    """Return the Euclidean distance between two points.

    Coordinates are paired by index over the length of ``pt1``. ``pt2`` must
    have at least that many coordinates; any extra ones are ignored.
    """
    squared = 0.0
    for idx, coord in enumerate(pt1):
        delta = coord - pt2[idx]
        squared += delta * delta
    return math.sqrt(squared)

if __name__ == "__main__":
    # Demo: build a tree over six 2-D points, then look up the stored point
    # closest to (2.1, 3.1) and print (point, distance).
    seed_node = KD_node()
    sample_points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    tree = createKDTree(seed_node, sample_points)
    print(findNN(tree, [2.1, 3.1]))

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值