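I've been working through machine learning fundamentals lately, and I nearly gave up on writing this one; it made me want to throw up. It took about 4 hours, but it finally runs and gives the right answer.
Ugh, recursion: when you get lucky it's easy to get right, and when you're unlucky you just keep getting it wrong. Blegh!!!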
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Build a KD tree
# Tutorial video: https://www.bilibili.com/video/BV1JE41147KA?p=2&spm_id_from=pageDriver
import math


class Node:
    def __init__(self, data, left, right, parent):
        self.data = data
        self.left = left
        self.right = right
        self.parent = parent


def get_median(data):
    # Index of the median element: size / 2 for an even-length list,
    # (size - 1) / 2 for an odd-length list. Both equal size // 2.
    return len(data) // 2


def make_KD_tree(root, dimension, median, data):
    # data is sorted along the parent's axis with the parent's point removed.
    # median is the index that point occupied, so data[:median] lies on the
    # left of the split and data[median:] on the right.
    if len(data) == 0:
        root.right = None
        root.left = None
        return
    left_data_list = data[:median]
    if len(left_data_list) == 0:
        root.left = None
    else:
        # Sort along the current dimension
        if dimension == 'y':
            left_data_list.sort(key=lambda x: (x[1], x[0]))
        else:
            left_data_list.sort(key=lambda x: (x[0], x[1]))
        # Take the median index before removing the point: recomputing it on
        # the shortened list is off by one when the list length is even.
        left_median = get_median(left_data_list)
        left_data = left_data_list.pop(left_median)
        left_node = Node(left_data, None, None, root)
        root.left = left_node
        make_KD_tree(left_node, 'x' if dimension == 'y' else 'y', left_median, left_data_list)
    right_data_list = data[median:]
    if len(right_data_list) == 0:
        root.right = None
    else:
        if dimension == 'y':
            # Sort by y
            right_data_list.sort(key=lambda x: (x[1], x[0]))
        else:
            # Sort by x
            right_data_list.sort(key=lambda x: (x[0], x[1]))
        right_median = get_median(right_data_list)
        right_data = right_data_list.pop(right_median)
        right_node = Node(right_data, None, None, root)
        root.right = right_node
        make_KD_tree(right_node, 'x' if dimension == 'y' else 'y', right_median, right_data_list)


def find_nearest(root, test, axis=0, best=None):
    # Return the node closest to test. axis 0 compares x, axis 1 compares y;
    # the root splits on x and the axis alternates with depth.
    if root is None:
        return best
    if best is None or math.dist(root.data, test) < math.dist(best.data, test):
        best = root
    # Search the side of the split that contains test first.
    if test[axis] < root.data[axis]:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left
    best = find_nearest(near, test, 1 - axis, best)
    # The other side can only hold a closer point if the splitting line is
    # nearer to test than the best distance found so far.
    if abs(test[axis] - root.data[axis]) < math.dist(best.data, test):
        best = find_nearest(far, test, 1 - axis, best)
    return best


if __name__ == '__main__':
    students = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    # Sort by x
    students.sort(key=lambda x: (x[0], x[1]))
    median = get_median(students)
    root_data = students[median]
    students.remove(root_data)
    root = Node(root_data, None, None, None)
    make_KD_tree(root, 'y', median, students)
    # print(students.index((4, 7), 0, len(students)))
    test = (2, 4.5)
    nearest_node = find_nearest(root, test)
    nearest = math.dist(nearest_node.data, test)
    print('nearest', nearest, 'nearest_node', nearest_node.data)
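
Since recursion is so easy to get subtly wrong, a quick way to check the KD tree's answer is to compare it against a plain linear scan over the same points. The sketch below is only a sanity check, not part of the KD tree itself; `brute_force_nearest` is a helper name made up for illustration, and it assumes Python 3.8+ for `math.dist`.

```python
import math


def brute_force_nearest(points, query):
    """Return (distance, point) for the point closest to query by linear scan."""
    # Hypothetical helper for cross-checking; O(n) per query, no tree needed.
    return min((math.dist(p, query), p) for p in points)


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
distance, point = brute_force_nearest(points, (2, 4.5))
print('nearest', distance, 'nearest_node', point)
```

If the tree search and the linear scan disagree on some query point, the bug is almost certainly in the tree code, since the scan is too simple to get wrong.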