《统计学习方法》KD树的构建与查找Python实现

关于KD树的构建和查找的理论可参考《统计学习方法》第三章以及这篇博文https://zhuanlan.zhihu.com/p/23966698,

import os
import csv
import numpy as np
import string
import pandas as pd
import operator
import re as re
import time    
import datetime

class KD_node():
	def __init__(self, vector, label, depth, dimension, l_node = None, r_node = None):
		self.val = vector
		self.label = label
		self.depth = depth
		self.dimension = dimension
		self.l_node = l_node
		self.r_node = r_node

	def add_lnode(self, new_node):
		self.l_node = new_node
	
	def add_rnode(self, new_node):
		self.r_node = new_node

	def print_preorder(self):
		print self.val
		if self.l_node:
			self.l_node.print_preorder()
		if self.r_node:
			self.r_node.print_preorder()

	def search(self, fea_vec, L, k):
		if self.val[self.dimension] < fea_vec[self.dimension]:#目标点在当前结点左边
			if self.l_node != None:
				L = self.l_node.search(fea_vec, L, k)  #若还有左结点则搜索左结点

			L = self.insertL(fea_vec, L, k) #对当前结点执行插入操作(包含了是否插入条件检测)
			if len(L) < k:  #若L仍然未满,则不用考虑,直接搜索右子树
				L = self.r_node.search(fea_vec, L, k)
			else if (self.val[self.dimension] - fea_vec[self.dimension])**2 < L[-1][1]: #判断右子树是否可能存在能插入L的结点,判断方法是求目标点到切割超平面的距离,此距离也就是切割维度上,切割点与目标点的距离,并将此距离与L内最大距离比较,若小于L内最大距离,则说明在切割超平面的另一边仍有可能存在能插入L的结点,我认为这一步,这个判断条件也是KD树提高搜索效率的关键所在
				L = self.r_node.search(fea_vec, L, k)

		else: #目标点在当前结点右边,跟上边的情况是对称的,就不写那么多注释了
			if self.r_node != None:
				L = self.r_node.search(fea_vec, L, k)
			L = self.insertL(fea_vec, L, k)
			if len(L) < k:
				L = self.l_node.search(fea_vec, L, k)
			else if (self.val[self.dimension] - fea_vec[self.dimension])**2 < L[-1][1]:
				L = self.l_node.search(fea_vec, L, k)
		return L


	def insertL(self, fea_vec, L, k):  #插入操作,检查当前结点是否可以插入L,若可以则插入,否则不做操作
		distance = sum((fea_vec - self.val)**2)
		if len(L) < k:  #如果L还没满,则直接插入
			L.append([self, distance])
			L = L[L[:, 1].argsort()]             #插入后排序,对Python不太熟,急于实现算法,用的效率比较低的方式,见谅。。。
		else if distance < L[-1][1]: #若L已满,判断L中最大距离是否比当前结点距离更大
			L[-1] = [self, distance]
			L = L[L[:, 1].argsort()]
                return L



def construct(data, depth):
	if len(data) == 0:
		return None
	dimension_sum = np.shape(data)[1] - 1 #特征维度,减1是因为label也占了一列
	dimension = depth % dimension_sum + 1 #切割轴维度
	data = data[data[:, dimension].argsort()] 
	node_data = data[len(data)/2]
	new_node = KD_node(node_data[1:], node_data[0], depth, dimension)
	new_node.l_node = construct(data[:len(data)/2], depth + 1)
	new_node.r_node = construct(data[len(data)/2 + 1:], depth + 1)
	return new_node

class KD_Tree():  #与一般的二叉排序树结构差不多,区别在于KD树需要反复使用各个维度来比较以构造二叉树
	# def __init__(self, data_root):
	# 	data_file = pd.read_csv(data_root, header = None)
	# 	self.data =	np.array(data_file)[:, 2:]
	# 	self.label = np.array(data_file)[:, 1]

	def __init__(self, data):
		self.data = data[:,:2]
		self.label = data[:, 2]


	def constructor(self):
		data = np.c_[self.label, self.data]  #合并特征向量和label
		data = data[data[:, 1].argsort()]  #按第一维排序(第0维是label)
 		self.root = construct(data, 0)

 	def print_preorder(self):
 		print 'preorder'
 		self.root.print_preorder()

 	def predict(self, fea_vec, k = 3):
 		L = []      #存储当前检测到的距离最近的K个点,初始为空,L中的元素为[KD_node, distance]
 		self.root(fea_vec, L, k)
 		label_dict = {}
 		for each_sample in L:
 			label_dict[each_sample.label] += 1
 		print(sorted(label_dict,key=lambda x:label_dict[x])[-1])  #打印出k个最近邻结点中个数最多的类别
代码还没有调试,思路应该没错,明天调好再更新一下。。。。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值