Python KD-Tree Implementation + Simple KNN

I didn't take class labels into account when writing the KD-tree, so the KNN first uses the tree to find the k nearest points, then looks up each point's class and outputs the majority class.

A KD-tree is a binary tree used to partition points in space.

A tree node looks like this:

class TreeNode:
    index = -1    # split dimension of this node
    point = None  # the point stored at this node
    left = None   # left subtree
    right = None  # right subtree
    data = None

    def __init__(self, index=-1, point=None, left=None, right=None):
        self.index = index
        self.point = point
        self.left = left
        self.right = right

    def set_data(self, data):
        self.data = data

    def get_data(self):
        return self.data

Building the tree works like this:

First pick the dimension with the largest variance.

Sort the current data along that dimension.

Take the median point.

The median point becomes this node's data.

Points to the left of the median go into the left-subtree construction, points to the right likewise into the right.

The next level of the tree splits on the next dimension.

The code:

def build_tree(self, dataset, split):
    # empty dataset -> no node
    if dataset is None or len(dataset) == 0:
        return None
    # wrap the split dimension when it runs past the last one,
    # matching the `% len(point)` wrap used in insert() below
    split %= len(dataset[0])
    # a single point must be a leaf
    if len(dataset) == 1:
        return TreeNode(split, dataset[0], None, None)
    dataset.sort(key=lambda x: x[split])
    point_index = len(dataset) // 2
    node = TreeNode()
    node.index = split
    node.point = dataset[point_index]
    node.left = self.build_tree(dataset[0:point_index], split + 1)
    node.right = self.build_tree(dataset[point_index + 1:], split + 1)
    return node

def create(self, dataset):
    # start splitting on the dimension with the largest variance
    split = self.get_var(dataset)
    self.root = self.build_tree(dataset, split)
    return self.root
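The post never shows `get_var`; from how it is used, it presumably returns the index of the dimension with the largest variance. A minimal sketch under that assumption (written here as a free function; in the class it would take `self` as its first parameter):

```python
def get_var(dataset):
    # Pick the dimension whose values have the largest variance.
    # Assumption: dataset is a non-empty list of equal-length points.
    dims = len(dataset[0])
    best_split, best_var = 0, -1.0
    for d in range(dims):
        values = [p[d] for p in dataset]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        if var > best_var:
            best_split, best_var = d, var
    return best_split
```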

Sometimes checked data needs to be inserted into an existing tree.

Insertion is simple: starting from the root, compare on each node's split dimension; if the new point's value on that dimension is smaller than the node's, keep searching the left subtree, otherwise the right.

When the search reaches None, insert the new node there.

def insert(self, point):
    if self.root is None:
        print('Build a tree first!')
        return
    if len(point) != len(self.root.point):
        print('This point has {l} dimensions but the tree has {m}'.format(
            l=len(point), m=len(self.root.point)))
        return
    node = self.root
    while True:
        if point[node.index] < node.point[node.index]:
            if node.left is not None:
                node = node.left
            else:
                split = (node.index + 1) % len(point)
                node.left = TreeNode(split, point, None, None)
                return
        else:
            if node.right is not None:
                node = node.right
            else:
                split = (node.index + 1) % len(point)
                node.right = TreeNode(split, point, None, None)
                return

To search, first descend to the closest leaf exactly as in insertion, pushing each visited node onto a stack.

Then backtrack upward one node at a time: if the other subtree of a popped node could still contain a closer point, push that subtree onto the stack as well.

When the stack is empty, the nearest point has been found. Any distance (similarity) metric can be plugged in for the point-to-point distance.
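The search code below relies on a `Stack` class with `push`, `pop`, and `isempty` that the post never defines. A minimal list-backed sketch (the capacity argument exists only to mirror the `Stack(99999)` call sites):

```python
class Stack:
    # Minimal LIFO stack backed by a Python list.
    def __init__(self, capacity=None):
        self.capacity = capacity
        self.items = []

    def push(self, item):
        if self.capacity is not None and len(self.items) >= self.capacity:
            raise OverflowError('stack is full')
        self.items.append(item)

    def pop(self):
        return self.items.pop()

    def isempty(self):
        return len(self.items) == 0
```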

from math import sqrt

def sim_distance(self, p1, p2):
    # Euclidean distance between two points
    sum_of_squares = sum(pow(p1[i] - p2[i], 2) for i in range(len(p1)))
    return sqrt(sum_of_squares)

def find_nearest(self, point):
    node = self.root
    s = Stack(99999)
    # descend to the closest leaf, recording the path
    while node is not None:
        s.push(node)
        if point[node.index] <= node.point[node.index]:
            node = node.left
        else:
            node = node.right
    nearest = s.pop()
    min_dist = self.sim_distance(nearest.point, point)
    # backtrack: check whether the other side of each splitting
    # plane could still hold a closer point
    while not s.isempty():
        back_point = s.pop()
        if back_point is None:
            continue
        index = back_point.index
        # distance from the query to the splitting plane
        if self.sim_distance([point[index]], [back_point.point[index]]) < min_dist:
            if point[index] <= back_point.point[index]:
                s.push(back_point.right)
            else:
                s.push(back_point.left)
        if min_dist > self.sim_distance(back_point.point, point):
            nearest = back_point
            min_dist = self.sim_distance(back_point.point, point)
    return nearest.point, min_dist

The core of KNN is finding the k nearest points, then deciding the query point's class from the classes of those points.

I maintain a list that always holds the k smallest distances found so far.

Each candidate is compared against the tail of the list; if it is smaller, it is added to the list, the list is re-sorted, and the first k entries are kept.

def find_near_kth(self, point, k):
    node = self.root
    result = []  # (node, distance) pairs, at most k of them
    s = Stack(99999)
    # descend to the closest leaf, recording the path
    while node is not None:
        s.push(node)
        if point[node.index] <= node.point[node.index]:
            node = node.left
        else:
            node = node.right
    t_point = s.pop()
    result.append((t_point, self.sim_distance(t_point.point, point)))
    while not s.isempty():
        back_point = s.pop()
        if back_point is None:
            continue
        index = back_point.index
        # explore the other side if the plane is within the current
        # k-th distance, or we do not yet have k candidates
        if self.sim_distance([point[index]], [back_point.point[index]]) <= result[-1][1] \
                or len(result) < k:
            if point[index] <= back_point.point[index]:
                s.push(back_point.right)
            else:
                s.push(back_point.left)
        if result[-1][1] > self.sim_distance(back_point.point, point) or len(result) < k:
            result.append((back_point, self.sim_distance(back_point.point, point)))
            result.sort(key=lambda x: x[1])
            result = result[0:k]
    return result
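As a design note, the sort-and-slice bookkeeping above re-sorts the list on every insertion; a bounded max-heap does the same job without repeated sorting. A sketch of just that bookkeeping (not the post's code), using `heapq` with negated distances since `heapq` is a min-heap:

```python
import heapq

def keep_k_smallest(distances, k):
    # Keep the k smallest values seen, using a max-heap of size k
    # (store negated values so the heap root is the current maximum).
    heap = []
    for d in distances:
        if len(heap) < k:
            heapq.heappush(heap, -d)
        elif d < -heap[0]:
            # new value beats the current k-th smallest: replace it
            heapq.heapreplace(heap, -d)
    return sorted(-d for d in heap)
```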

Finally, a rather clumsy way to look up each point's class:

def decide_type(kd_result, t_point, t_type):
    # count votes per class among the k nearest points
    ans = {t: 0 for t in t_type}
    for node in kd_result:
        for i in range(len(t_point)):
            if node[0].point == t_point[i]:
                ans[t_type[i]] += 1
                break
    # pick the class with the most votes
    max_v = 0
    max_type = None
    for t in ans:
        if ans[t] > max_v:
            max_v = ans[t]
            max_type = t
    return max_type
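The same majority vote can be written more compactly with `collections.Counter`. A sketch assuming the same shapes as above (`kd_result` is the (node, distance) list from `find_near_kth`; `t_point` and `t_type` are parallel lists of training points and labels; `decide_type_counter` is a hypothetical name):

```python
from collections import Counter

def decide_type_counter(kd_result, t_point, t_type):
    # Collect the label of each neighbor by matching its point
    # against the training set, then take the most common label.
    labels = []
    for node, _dist in kd_result:
        for p, label in zip(t_point, t_type):
            if node.point == p:
                labels.append(label)
                break
    if not labels:
        return None
    return Counter(labels).most_common(1)[0][0]
```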

A quick test:

kd = KdTree()
kd.create(train_point)
print(kd.find_near_kth((1, 1), 2))
# print(decide_type(kd.find_near_kth((6.5, 6), 3), train_point, train_type))
