import numpy as np
import matplotlib.pyplot as plt
def distance(x1, x2):
    """Return the Euclidean distance between two points.

    The coordinates of ``x1`` and ``x2`` are paired positionally, so the
    function works for points of any dimensionality and does not rely on
    any module-level setting.

    Args:
        x1: Iterable of numeric coordinates.
        x2: Iterable of numeric coordinates.

    Returns:
        The square root of the summed squared coordinate differences, as
        produced by ``np.sqrt``.
    """
    # zip() stops at the shorter input, so surplus trailing coordinates of
    # the longer point are ignored rather than raising.
    return np.sqrt(sum(np.square(a - b) for a, b in zip(x1, x2)))
class Node:
    """One tree node: a stored point plus its depth and child links."""

    def __init__(self, key, depth=0, left=None, right=None):
        """Store the point and its position in the tree.

        Args:
            key: The data point held by this node.
            depth: Level of the node in the tree; defaults to 0.
            left: Left child node, or None when there is none.
            right: Right child node, or None when there is none.
        """
        self.depth, self.left, self.right, self.key = depth, left, right, key
class SDTree:
    """k-d tree built by median splits, with k-nearest-neighbour search.

    Points are indexable sequences of numbers. The split axis of a node is
    ``node.depth % dimension``, so the axes are cycled level by level.
    """

    def __init__(self, dimension=None):
        """Create an empty tree.

        Args:
            dimension: Number of leading coordinates to cycle through when
                choosing split axes. When None (the default) it is taken
                from the length of the first point passed to ``create``.
        """
        # Root node; stays None until create() builds a non-empty tree.
        self.root = None
        # Result of the latest find(): object array of [distance, node] rows.
        self.nearest = None
        # Number of coordinates cycled through for splitting.
        self.dimension = dimension
        # Remember whether the caller fixed the dimension explicitly, so a
        # later create() on different data can re-infer it otherwise.
        self._dimension_given = dimension is not None

    def create(self, dataset, depth=0):
        """Recursively build a (sub)tree from ``dataset``.

        Args:
            dataset: Sequence of points; it is not modified.
            depth: Depth assigned to the node created by this call. External
                callers normally leave it at 0.

        Returns:
            The node rooting the built subtree, or None for an empty
            ``dataset``. A call with ``depth == 0`` and a non-empty dataset
            also stores the node as ``self.root``.
        """
        if len(dataset) == 0:
            return None
        if self.dimension is None or (depth == 0 and not self._dimension_given):
            self.dimension = len(dataset[0])
        axis = depth % self.dimension
        ordered = sorted(dataset, key=lambda point: point[axis])
        # Upper median for even lengths; everything before it goes left,
        # everything after it goes right.
        median = len(ordered) // 2
        node = Node(ordered[median], depth)
        node.left = self.create(ordered[:median], depth + 1)
        node.right = self.create(ordered[median + 1:], depth + 1)
        if depth == 0:
            self.root = node
        return node

    def display(self, node):
        """Print ``depth key`` for every node below ``node``, in order."""
        if node is not None:
            self.display(node.left)
            print(node.depth, node.key)
            self.display(node.right)

    def show(self):
        """Print ``distance key`` for each neighbour found by ``find``.

        Prints nothing when ``find`` has not been called yet; unfilled
        placeholder rows are skipped.
        """
        if self.nearest is None:
            return
        for dist, node in self.nearest:
            if node is None:
                continue
            print(dist, node.key)

    def find(self, x, count=1):
        """Search the tree for the ``count`` points nearest to ``x``.

        The result is stored in ``self.nearest`` as an object array of shape
        ``(count, 2)`` whose rows are ``[distance, node]`` in ascending
        distance order. When the tree holds fewer than ``count`` points the
        remaining rows are ``[-1, None]``. A ``count`` of zero or less
        yields an array with no rows.

        Args:
            x: Query point; must be indexable and iterable.
            count: Number of neighbours wanted.
        """
        # [distance, node] pairs, ascending by distance, never longer than
        # `count`. A plain list is used because it supports in-place insert.
        best = []
        dims = self.dimension

        def search(node):
            if node is None:
                return
            axis = node.depth % dims
            gap = node.key[axis] - x[axis]
            # Descend first into the side of the splitting plane holding x.
            if gap >= 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            search(near)

            d = np.sqrt(sum(np.square(p - q) for p, q in zip(x, node.key)))
            # First slot whose distance is strictly larger; ties keep the
            # earlier-found point in front.
            pos = next(
                (i for i, (other, _) in enumerate(best) if d < other),
                len(best),
            )
            if pos < count:
                best.insert(pos, [d, node])
                del best[count:]  # keep only the `count` nearest points

            # The far side must be visited while neighbours are still
            # missing; once the list is full it can only help if the
            # splitting plane is closer than the current worst neighbour.
            if len(best) < count or best[-1][0] > abs(gap):
                search(far)

        if count > 0:
            search(self.root)

        result = np.empty((max(count, 0), 2), dtype=object)
        result[:, 0] = -1
        result[:, 1] = None
        for row, (d, node) in enumerate(best):
            result[row, 0] = d
            result[row, 1] = node
        self.nearest = result
# Sample points and the number of coordinates per point.
S = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
dimension = 2

# Build the tree over the sample points.
tree = SDTree()
tree.create(S)

# Look up the three points nearest to the query and print them.
X = np.array([7, 1])
tree.find(X, 3)
tree.show()
# tree.display(tree.root)

# Coordinates of every sample point, split per axis for plotting.
xx = [point[0] for point in S]
yy = [point[1] for point in S]

# Coordinates of the neighbours that the search returned.
XX = [found[1].key[0] for found in tree.nearest]
YY = [found[1].key[1] for found in tree.nearest]

plt.scatter(xx, yy, label='init dots', marker='o')
plt.scatter(XX, YY, label='target dots', marker='x')
plt.scatter(X[0], X[1], label='test dot', marker='^')
plt.legend()
plt.savefig("result.png")
# Source: Example 3.2 in "Statistical Learning Methods" (统计学习方法).