kd树 python实现_统计学习方法第三章:k近邻法(k-NN),kd树及python实现

欢迎关注公众号:常失眠少年,大学生的修炼手册!

k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。

k近邻法假设给定一个训练数据集,其中的实例类别已定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,k邻近法不具有显式的学习过程。

k近邻法实际上利用训练数据集对特征空间进行划分,并作为其分类的“模型”。

k值的选择,距离度量及分类决策规则是k近邻法的三个基本要素。

下图是k近邻法:

15f28da3969b

k近邻法

实现k近邻法时,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这点在特征空间的维数大以及训练数据容量大时尤其必要

k近邻法最简单的实现方式是线性扫描。这时要计算输入实例与每一个训练实例的距离,当训练集很大时,计算非常耗时,这种方法是不可行的

为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。其中一种方法就是kd树方法

下图是kd树的构造算法:

15f28da3969b

kd树构造算法

下图是kd树的搜索算法:

15f28da3969b

kd树搜索算法

更具体的解释和证明可以看《统计学习方法》或者其他解释kd树的博文,我在这里不再赘述

下面是python代码实现,使用MINST数据集,构造kd树进行搜索,实现的是最近邻算法,即只搜寻最近的一个实例来决定类别

但有一个问题是运算很慢,我也不得其解,但算法核心部分实现应当是无误的

import pandas as pd

import numpy as np

import cv2

import logging

import time

from math import sqrt

from collections import namedtuple

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score

def log(func):

def wrapper(*args, **kwargs):

start_time = time.time()

logging.debug('start %s()' % func.__name__)

ret = func(*args, **kwargs)

end_time = time.time()

logging.debug('end %s(), cost %s seconds' % (func.__name__, end_time - start_time))

return ret

return wrapper

def get_hog_features(trainset):

# 利用opencv获取图像hog特征

features = []

hog = cv2.HOGDescriptor('../hog.xml')

for img in trainset:

img = np.reshape(img, (28, 28))

cv_img = img.astype(np.uint8)

hog_feature = hog.compute(cv_img)

# hog_feature = np.transpose(hog_feature)

features.append(hog_feature)

features = np.array(features)

features = np.reshape(features, (-1, 324))

return features

def predict(test_set, kd_tree):

predict = []

for i in range(len(test_set)):

predict.append(find_nearest(kd_tree, test_set[i]).label)

return np.array(predict)

class KdNode(object):

def __init__(self, dom_elt, split, left, right, label):

self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)

self.split = split # 整数(进行分割维度的序号)

self.left = left # 该结点分割超平面左子空间构成的kd-tree

self.right = right # 该结点分割超平面右子空间构成的kd-tree

self.label = label

class KdTree(object):

@log

def __init__(self, data, labels):

k = len(data[0]) # 数据维度

def create_node(split, data_set, labels): # 按第split维划分数据集,创建KdNode

# print(len(data_set))

if (len(data_set) == 0):

return None

sort_index = data_set[:, split].argsort()

data_set = data_set[sort_index]

labels = labels[sort_index]

# print(data_set)

split_pos = len(data_set) // 2

# print(split_pos)

median = data_set[split_pos] # 中位数分割点

label = labels[split_pos]

split_next = (split + 1) % k # cycle coordinates

# 递归的创建kd树

return KdNode(median, split,

create_node(split_next, data_set[:split_pos], labels[:split_pos]), # 创建左子树

create_node(split_next, data_set[split_pos + 1:], labels[split_pos + 1:]), # 创建右子树

label)

self.root = create_node(0, data, labels) # 从第0维分量开始构建kd树,返回根节点

# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数

result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited label")

@log

def find_nearest(tree, point):

k = len(point) # 数据维度

def travel(kd_node, target, max_dist):

if kd_node is None:

return result([0] * k, float("inf"), 0, 0) # python中用float("inf")和float("-inf")表示正负无穷

nodes_visited = 1

s = kd_node.split # 进行分割的维度

pivot = kd_node.dom_elt # 进行分割的“轴”

if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)

nearer_node = kd_node.left # 下一个访问节点为左子树根节点

further_node = kd_node.right # 同时记录下右子树

else: # 目标离右子树更近

nearer_node = kd_node.right # 下一个访问节点为右子树根节点

further_node = kd_node.left

if (nearer_node is None):

label = 0

else:

label = nearer_node.label

temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域

nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”

dist = temp1.nearest_dist # 更新最近距离

nodes_visited += temp1.nodes_visited

if dist < max_dist:

max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内

temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离

if max_dist < temp_dist: # 判断超球体是否与超平面相交

return result(nearest, dist, nodes_visited, temp1.label) # 不相交则可以直接返回,不用继续判断

# ----------------------------------------------------------------------

# 计算目标点与分割点的欧氏距离

temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))

if temp_dist < dist: # 如果“更近”

nearest = pivot # 更新最近点

dist = temp_dist # 更新最近距离

max_dist = dist # 更新超球体半径

label = kd_node

# 检查另一个子结点对应的区域是否有更近的点

temp2 = travel(further_node, target, max_dist)

nodes_visited += temp2.nodes_visited

if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离

nearest = temp2.nearest_point # 更新最近点

dist = temp2.nearest_dist # 更新最近距离

label = temp2.label

return result(nearest, dist, nodes_visited, label)

return travel(tree.root, point, float("inf")) # 从根节点开始递归

k = 10

if __name__ == '__main__':

logger = logging.getLogger()

logger.setLevel(logging.DEBUG)

raw_data = pd.read_csv('../data/train.csv', header=0)

data = raw_data.values

images = data[0:, 1:]

labels = data[:, 0]

features = get_hog_features(images)

# 选取 2/3 数据作为训练集, 1/3 数据作为测试集

train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33,random_state=1)

kd_tree = KdTree(train_features, train_labels)

test_predict = predict(test_features, kd_tree)

score = accuracy_score(test_labels, test_predict)

print("The accuracy score is ", score)

水平有限,如有错误,希望指出

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值