kd树 python实现_k近邻的kd树构造与搜索python实现

import numpy as np

import matplotlib.pyplot as plt

def distance(x1, x2):

tmp = 0

for i in range(dimension):

tmp += np.square(x1[i] - x2[i])

return np.sqrt(tmp)

class Node:

def __init__(self, key, depth=0, left=None, right=None):

self.depth = depth

self.left = left

self.right = right

self.key = key

class SDTree:

def __init__(self):

self.root = None

self.nearest = None

def create(self, dataset, depth=0):

if len(dataset) > 0:

median = len(dataset) // 2

axis = depth % dimension

copy = sorted(dataset, key=lambda x: x[axis])

node = Node(copy[median], depth)

node.left = self.create(copy[:median], depth+1)

node.right = self.create(copy[median+1:], depth+1)

if depth == 0:

self.root = node

return node

return None

def display(self, node):

if node is not None:

self.display(node.left)

print(node.depth, node.key)

self.display(node.right)

def show(self):

for item in self.nearest:

print(item[0], item[1].key)

def find(self, x, count=1):

node_set = []

for i in range(count):

node_set.append([-1, None])

self.nearest = np.array(node_set) # 转换成 darray 是因为 list 不好插入

def recurve(node):

if node is None:

return

axis = node.depth % dimension

dis = node.key[axis] - x[axis]

if dis >= 0:

recurve(node.left)

else:

recurve(node.right)

d = np.sqrt(sum(np.square(p - q) for (p, q) in zip(x, node.key)))

for ii, item in enumerate(self.nearest): # 小白觉得这么写好厉害

if item[0] < 0 or d < item[0]:

self.nearest = np.insert(self.nearest, ii, [d, node], axis=0) # 保证留下来的是最近的

self.nearest = self.nearest[:-1] # 只留 count 个最近点

break

n = list(self.nearest[:, 0]).count(-1) # 转换成 list 是因为 darray 没有count方法

if self.nearest[-n-1, 0] > abs(dis):

if dis >= 0:

recurve(node.right)

else:

recurve(node.left)

recurve(self.root)

S = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]

dimension = 2

tree = SDTree()

tree.create(S)

X = np.array([7, 1])

tree.find(X, 3)

tree.show()

# tree.display(tree.root)

xx, yy = [2, 5, 9, 4, 8, 7], [3, 4, 6, 7, 1, 2]

XX, YY = [], []

for it in tree.nearest:

XX.append(it[1].key[0])

YY.append(it[1].key[1])

plt.scatter(xx, yy, label='init dots', marker='o')

plt.scatter(XX, YY, label='target dots', marker='x')

plt.scatter(7, 1, label='test dot', marker='^')

plt.legend()

plt.savefig("result.png")

6cce85032239

统计学习方法例3.2

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值