先写一个util.py
import math
def distance(x1, x2):
    """Return the Euclidean distance between two points.

    Args:
        x1: Sequence of numeric coordinates.
        x2: Sequence of numeric coordinates with the same length as ``x1``.

    Returns:
        The distance as a float (``0.0`` when both points are empty).

    Raises:
        ValueError: If the points have a different number of coordinates.
    """
    # Reject mismatched dimensions up front: silently dropping the extra
    # coordinates of the longer point would produce a meaningless distance.
    if len(x1) != len(x2):
        raise ValueError(
            "points must have the same dimension, got %d and %d"
            % (len(x1), len(x2))
        )
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x1, x2)))
def sort(X, k, efrom, eto):
    """Sort the rows ``X[efrom]`` .. ``X[eto]`` (both inclusive) by column ``k``.

    The sort happens in place: only the given sub-range of ``X`` is
    reordered and nothing is returned. Rows with equal keys keep their
    relative order (the built-in sort is stable).

    Args:
        X: Mutable sequence of indexable rows (e.g. a list of points).
        k: Index of the coordinate used as the sort key.
        efrom: Index of the first row in the range to sort.
        eto: Index of the last row in the range to sort (inclusive).
            An empty range (``eto < efrom``) leaves ``X`` untouched.
    """
    # Slice assignment writes the sorted rows back into the same list object,
    # so every caller holding a reference to X observes the new order.
    X[efrom:eto + 1] = sorted(X[efrom:eto + 1], key=lambda row: row[k])
再写个KDTree的生成,kdt.py
import util
class Node:
    """A single node of a k-d tree.

    Attributes are plain instance fields:
      layer  -- index of the coordinate axis this node splits on.
      elem   -- the point stored at this node (``None`` until ``kd`` runs).
      lchild -- left subtree (``Node``) or ``None`` when there is none.
      rchild -- right subtree (``Node``) or ``None`` when there is none.
    """

    def __init__(self, _layer):
        self.layer = _layer
        self.elem = None
        # Always define both children so leaves can be tested with
        # ``is None`` instead of raising AttributeError.
        self.lchild = None
        self.rchild = None

    def kd(self, X, layer, efrom, eto, k):
        """Recursively build the subtree over ``X[efrom]`` .. ``X[eto]``.

        The range is sorted in place along axis ``layer``; the median row
        becomes this node's element, the rows before it form the left
        subtree and the rows after it form the right subtree. Child nodes
        split on the next axis, cycling through the ``k`` axes.

        Args:
            X: Mutable list of points; it is reordered in place.
            layer: Axis index used to split at this node.
            efrom: Index of the first point of the range (inclusive).
            eto: Index of the last point of the range (inclusive).
            k: Number of axes to cycle through.
        """
        # An empty range has no median; leave this node without an element
        # rather than indexing outside the requested range.
        if efrom > eto:
            return

        util.sort(X, layer, efrom, eto)
        # Upper median for even-sized ranges.
        median = efrom + (eto - efrom + 1) // 2
        self.elem = X[median]

        next_layer = (layer + 1) % k
        if median > efrom:
            self.lchild = Node(next_layer)
            self.lchild.kd(X, next_layer, efrom, median - 1, k)
        if median < eto:
            self.rchild = Node(next_layer)
            self.rchild.kd(X, next_layer, median + 1, eto, k)
class KDTree:
    """Builds a k-d tree over a list of points.

    Attributes are plain instance fields:
      k    -- number of coordinate axes the tree cycles through.
      root -- root ``Node`` of the most recently built tree, or ``None``
              if no tree has been built (or the input was empty).
    """

    def __init__(self, _k):
        self.k = _k
        self.root = None

    def growing(self, X):
        """Build the tree from ``X`` and return its root node.

        ``X`` is reordered in place while the tree is built. The element of
        the root's right child is printed when such a child exists.

        Args:
            X: Mutable list of points.

        Returns:
            The root ``Node``, or ``None`` when ``X`` is empty.
        """
        if len(X) == 0:
            # Nothing to build; avoid indexing into an empty list.
            self.root = None
            return None

        root = Node(0)
        root.kd(X, 0, 0, len(X) - 1, self.k)
        # Keep the tree reachable so later queries can use it.
        self.root = root

        # A tree built from a single point has no right child; only print
        # when there is something to show.
        rchild = getattr(root, "rchild", None)
        if rchild is not None:
            print(rchild.elem)
        return root
测试一下
import kdt
# Demo: build a k-d tree over six sample 2-D points.
# k is handed to KDTree as the number of axes to cycle through.
k = 2
X = [[2, 3],[5,4],[9,6], [4, 7], [8,1], [7,2]]
tree = kdt.KDTree(k)
# growing() builds the tree and prints the element held by the root's right child.
tree.growing(X)
# Tree construction sorts sub-ranges of X in place, so this prints the reordered list.
print(X)
下面看怎么找最近邻节点和KNN算法。