kd树介绍(KNN算法引出)

kd 树的结构

kd树是一个二叉树结构,它的每一个节点记载了 [特征坐标, 切分轴, 指向左枝的指针, 指向右枝的指针] 。 其中, 特征坐标是线性空间 R n \mathbb{R}^{n} Rn 中的一个点 ( x 1 , x 2 , … , x n ) ∘ \left(x_{1}, x_{2}, \ldots, x_{n}\right)_{\circ} (x1,x2,,xn)

切分轴由一个整数 r r r 表示, 这里 1 ≤ r ≤ n , 1 \leq r \leq n, 1rn, 是我们在 n n n 维空间中沿第 r r r 维进行一次分割。 节点的左枝和右枝分别都是 kd 树, 并且满足:如果 y y y 是左枝的一个特征坐标, 那么 y r ≤ x r ; y_{r} \leq x_{r} ; yrxr; 并且如果 z z z 是右 枝的一个特征坐标,那么 z r ≥ x r ∘ z_{r} \geq x_{r \circ} zrxr

给定一个数据样本集 S ⊆ R n S \subseteq R^{n} SRn 和切分轴 r , r, r, 以下递归算法将构建一个基于该数据集的 kd 树, 每一次循环制作一 个节点:

  • 如果 ∣ S ∣ = 1 , |S|=1, S=1, 记录 S S S 中唯一的一个点为当前节点的特征数据, 并且不设左枝和右枝。 ( ∣ S ∣ \quad(|S| (S 指集合 S S S 中元素 ) ) ). 的数量) − - 如果 ∣ S ∣ > 1 : |S|>1: S>1:

  • 如果 ∣ S ∣ > 1 : |S|>1: S>1:

    • 将 S 内所有点按照第 r r r 个坐标的大小进行排序;
    • 选出该排列后的中位元素 (如果一共有偶数个元素, 则选择中位左边或右边的元素, 左或右并无影响),作为当前节点的特征坐标, 并且记录切分轴 r r r
      ∙ \bullet S L S_{L} SL 设为在 S S S 中所有排列在中位元素之前的元素; S R \quad S_{R} SR 设为在 S S S 中所有排列在中位元素后的元素;
    • 当前节点的左枝设为以 S L S_{L} SL 为数据集并且 r r r 为切分轴制作出的 k d \mathrm{kd} kd 树; 当前节点的右枝设为以 S R S_{R} SR 为数据集并且 r r r 为切分轴制作出 的 kd 树。再设 r ← ( r + 1 )   m o d   n ∘ ( r \leftarrow(r+1) \quad \bmod n_{\circ} \quad( r(r+1)modn( 这里, 我们想轮流沿着每一个维度进 行分割;   m o d   n \quad \bmod n modn 是因为一共有 n n n 个维度, 在 沿着最后一个维度进行分割之后再重新回到第一个维度。)

构造 kd 树的例子

上面抽象的定义和算法确实是很不好理解,举一个例子会清楚很多。首先随机在 R 2 \mathbb{R}^{2} R2 中随机生成 13 个点作为我们的数据集。起始的切分轴 r = 0 ; r=0 ; r=0; 这里 r = 0 r=0 r=0 对应 x x x 轴, 而 r = 1 r=1 r=1 对应 y y y 轴。
在这里插入图片描述
首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
在这里插入图片描述

并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左枝,x坐标大于 6.27 的点用于构建右枝。
在这里插入图片描述
 在下一步中  r = 0 + 1 = 1  mod  2  对应  y  轴, 左右两边再按照  y  轴的排序进行切分,中位点记裁于左右枝的  \text { 在下一步中 } r=0+1=1 \quad \text { mod } 2 \text { 对应 } y \text { 轴, 左右两边再按照 } y \text { 轴的排序进行切分,中位点记裁于左右枝的 }  在下一步中 r=0+1=1 mod 2 对应 y 左右两边再按照 y 轴的排序进行切分,中位点记裁于左右枝的 节点。得到下面的树,左边的x 是指这该层的节点都是沿 x 轴进行分割的。
在这里插入图片描述
空间的切分如下
在这里插入图片描述
下一步中 r ≡ 1 + 1 ≡ 0   m o d   2 , r \equiv 1+1 \equiv 0 \quad \bmod 2, r1+10mod2, 对应 x x x 轴, 所以下面再按照 x x x 除标进行排序和切分,有
在这里插入图片描述
在这里插入图片描述
最后每一部分都只剩一个点,将他们记在最底部的节点中。因为不再有未被记录的点,所以不再进行切分。
在这里插入图片描述
在这里插入图片描述
就此完成了 kd 树的构造。

kd 树上的 kNN 算法

给定一个构建于一个样本生的 kd 树, 下面的算法可以寻找距离某个点 p p p 最近的 k k k 个样本。

  1. L L L 为一个有 k k k 个空位的列表, 用于保存已搜寻到的最近点.

  2. 根据 p p p 的坐标值和每个节点的切分向下搜素(也就是选,如果树的节点是按照 x r = a x_{r}=a xr=a 进行切分,并且 p p p r r r 坐标小于 a , a, a, 则向左枝进行搜索: 反之则走右枝)。

  3. 当达到一个底部节点时,将其标记为访问过. 如果 L L L 里不足 k k k 个点. 则将当前节点的特征坐标加人 L : L: L: 如 果 L不为空并且当前节点 \quad 的特征与 p p p 的距离小于 L L L 里最长的距离,则用当前特征音换掉 L L L 中离 p p p 最远的点

  4. 如果当前节点不是整棵树最顶而节点, 执行 下(1):反之. 输出 L , L, L, 算法完成.
    (1) . 向上爬一个节点。如果当前 (向上爬之后的) 节点未管被访问过, 将其标记为被访问过, 然后执行 1和2:如果当前节点被访 问过, 再次执行 (1)。

    1. 如果此时 L L L 里不足 k k k 个点, 则将节点特征加入 L : L: L: 如果 L L L 中已满 k k k 个点, 且当前节点与 p p p 的距离小于 L L L 里最长的距离。 则用节点特征豐换掉 L L L 中帝最远的点。
    2. 计算 p p p 和当前节点切分綫的距离。如果该距离大于等于 L L L 中距离 p p p 最远的距离井且 L L L 中已有 k k k 个点。则在切分线另一边不会有更近的点, 执行3: 如果该距离小于 L L L 中最远的距离或者 L L L 中不足 k k k 个点, 则切分綫另一边可能有更近的点, 因此在当前节点的另一个枝从 1 1 1 开始执行.

来看下面的例子:
在这里插入图片描述
首先执行1,我们按照切分找到最底部节点。首先,我们在顶部开始
在这里插入图片描述
和这个节点的 x轴比较一下,
在这里插入图片描述
ppp 的 x 轴更小。因此我们向左枝进行搜索:
在这里插入图片描述
这次对比 y 轴,
在这里插入图片描述
p 的 y 值更小,因此向左枝进行搜索:
在这里插入图片描述
这个节点只有一个子枝,就不需要对比了。由此找到了最底部的节点 (−4.6,−10.55)。
在这里插入图片描述
在二维图上是
在这里插入图片描述
此时我们执行2。将当前结点标记为访问过, 并记录下 L = [ ( − 4.6 , − 10.55 ) ] . L=[(-4.6,-10.55)] . L=[(4.6,10.55)]. 访问过的节点就在二叉树 上显示为被划掉的好了。

然后执行 3,不是最顶端节点。执行 (1),我爬。上面的是 (−6.88,−5.4)。
在这里插入图片描述在这里插入图片描述

执行 1,因为我们记录下的点只有一个,小于k=3,所以也将当前节点记录下,有 L=[(−4.6,−10.55),(−6.88,−5.4)]。再执行 2,因为当前节点的左枝是空的,所以直接跳过,回到步骤3。3看了一眼,好,不是顶部,交给你了,(1)。于是乎 (1) 又往上爬了一节。
在这里插入图片描述
在这里插入图片描述
1 说,由于还是不够三个点,于是将当前点也记录下,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)。当然,当前结点变为被访问过的。

2又发现,当前节点有其他的分枝,并且经计算得出 p 点和 L 中的三个点的距离分别是 6.62,5.89,3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:
在这里插入图片描述
因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行 1。好,我们在红线这里:
在这里插入图片描述
要用 p 和这个节点比较 x 坐标:
在这里插入图片描述
p 的x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,因此启动 2。
在这里插入图片描述
经计算,(1.75,12.26)与 p 的距离是 7.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。
在这里插入图片描述
然后 3 判断出不是顶端节点,呼出 (1),爬。
在这里插入图片描述
1出来一算,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。
在这里插入图片描述
因此,我们用这个新的节点替代 L 中离 p 最远的 (−4.6,−10.55)。
在这里插入图片描述
然后 2又来了,我们比对 p 和当前节点的分割线的距离
在这里插入图片描述
这个距离小于 L 与 p 的最小距离,因此我们要到当前节点的另一个枝执行 1。当然,那个枝只有一个点,直接到 2。
在这里插入图片描述
计算距离发现这个点离 p 比 L 更远,因此不进行替代。
在这里插入图片描述
3发现不是顶点,所以呼出 (1)。我们向上爬,
在这里插入图片描述
这个是已经访问过的了,所以再来(1),
在这里插入图片描述
好,(1)再爬,
在这里插入图片描述
啊!到顶点了。所以完了吗?当然不,还没轮到 3 呢。现在是 1的回合。

我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。
在这里插入图片描述
然后是 2,计算 p 和分割线的距离发现也是更远。
在这里插入图片描述
因此也不需要检查另一个分枝。

然后执行 3,判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)].

C实现

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <float.h>
#include <math.h>

#include "kdtree.h"

static inline int is_leaf(struct kdnode *node)
{
        return node->left == node->right;
}

static inline void swap(long *a, long *b)
{
        long tmp = *a;
        *a = *b;
        *b = tmp;
}

static inline double square(double d)
{
        return d * d;
}

static inline double distance(double *c1, double *c2, int dim)
{
        double distance = 0;
        while (dim-- > 0) {
                distance += square(*c1++ - *c2++);
        }
        return distance;
}

static inline double knn_max(struct kdtree *tree)
{
        return tree->knn_list_head.prev->distance;
}

static inline double D(struct kdtree *tree, long index, int r)
{
        return tree->coord_table[index][r];
}

static inline int kdnode_passed(struct kdtree *tree, struct kdnode *node)
{
        return node != NULL ? tree->coord_passed[node->coord_index] : 1;
}

static inline int knn_search_on(struct kdtree *tree, int k, double value, double target)
{
        return tree->knn_num < k || square(target - value) < knn_max(tree);
}

static inline void coord_index_reset(struct kdtree *tree)
{
        long i;
        for (i = 0; i < tree->capacity; i++) {
                tree->coord_indexes[i] = i;
        }
}

static inline void coord_table_reset(struct kdtree *tree)
{
        long i;
        for (i = 0; i < tree->capacity; i++) {
                tree->coord_table[i] = tree->coords + i * tree->dim;
        }
}

static inline void coord_deleted_reset(struct kdtree *tree)
{
        memset(tree->coord_deleted, 0, tree->capacity);
}

static inline void coord_passed_reset(struct kdtree *tree)
{
        memset(tree->coord_passed, 0, tree->capacity);
}

static void coord_dump_all(struct kdtree *tree)
{
        long i, j;
        for (i = 0; i < tree->count; i++) {
                long index = tree->coord_indexes[i];
                double *coord = tree->coord_table[index];
                printf("(");
                for (j = 0; j < tree->dim; j++) {
                        if (j != tree->dim - 1) {
                                printf("%.2f,", coord[j]);
                        } else {
                                printf("%.2f)\n", coord[j]);
                        }
                }
        }
}

static void coord_dump_by_indexes(struct kdtree *tree, long low, long high, int r)
{
        long i;
        printf("r=%d:", r);
        for (i = 0; i <= high; i++) {
                if (i < low) {
                        printf("%8s", " ");
                } else {
                        long index = tree->coord_indexes[i];
                        printf("%8.2f", tree->coord_table[index][r]);
                }
        }
        printf("\n");
}

static void bubble_sort(struct kdtree *tree, long low, long high, int r)
{
        long i, flag = high + 1;
        long *indexes = tree->coord_indexes;
        while (flag > 0) {
                long len = flag;
                flag = 0;
                for (i = low + 1; i < len; i++) {
                        if (D(tree, indexes[i], r) < D(tree, indexes[i - 1], r)) {
                                swap(indexes + i - 1, indexes + i);
                                flag = i;
                        }
                }
        }
}

static void insert_sort(struct kdtree *tree, long low, long high, int r)
{
        long i, j;
        long *indexes = tree->coord_indexes;
        for (i = low + 1; i <= high; i++) {
                long tmp_idx = indexes[i];
                double tmp_value = D(tree, indexes[i], r);
                j = i - 1;
                for (; j >= low && D(tree, indexes[j], r) > tmp_value; j--) {
                        indexes[j + 1] = indexes[j];
                }
                indexes[j + 1] = tmp_idx;
        }
}

static void quicksort(struct kdtree *tree, long low, long high, int r)
{
        if (high - low <= 32) {
                insert_sort(tree, low, high, r);
                //bubble_sort(tree, low, high, r);
                return;
        }

        long *indexes = tree->coord_indexes;
        /* median of 3 */
        long mid = low + (high - low) / 2;
        if (D(tree, indexes[low], r) > D(tree, indexes[mid], r)) {
                swap(indexes + low, indexes + mid);
        }
        if (D(tree, indexes[low], r) > D(tree, indexes[high], r)) {
                swap(indexes + low, indexes + high);
        }
        if (D(tree, indexes[high], r) > D(tree, indexes[mid], r)) {
                swap(indexes + high, indexes + mid);
        }

        /* D(indexes[low]) <= D(indexes[high]) <= D(indexes[mid]) */
        double pivot = D(tree, indexes[high], r);

        /* 3-way partition
         * +---------+-----------+---------+-------------+---------+
         * |  pivot  |  <=pivot  |   ?     |  >=pivot    |  pivot  |
         * +---------+-----------+---------+-------------+---------+
         * low     lt             i       j               gt    high
         */
        long i = low - 1;
        long lt = i;
        long j = high;
        long gt = j;
        for (; ;) {
                while (D(tree, indexes[++i], r) < pivot) {}
                while (D(tree, indexes[--j], r) > pivot && j > low) {}
                if (i >= j) break;
                swap(indexes + i, indexes + j);
                if (D(tree, indexes[i], r) == pivot) swap(&indexes[++lt], &indexes[i]);
                if (D(tree, indexes[j], r) == pivot) swap(&indexes[--gt], &indexes[j]);
        }
        /* i == j or j + 1 == i */
        swap(indexes + i, indexes + high);

        /* Move equal elements to the middle of array */
        long x, y;
        for (x = low, j = i - 1; x <= lt && j > lt; x++, j--) swap(indexes + x, indexes + j);
        for (y = high, i = i + 1; y >= gt && i < gt; y--, i++) swap(indexes + y, indexes + i);

        quicksort(tree, low, j - lt + x - 1, r);
        quicksort(tree, i + y - gt, high, r);
}

static struct kdnode *kdnode_alloc(double *coord, long index, int r)
{
        struct kdnode *node = malloc(sizeof(*node));
        if (node != NULL) {
                memset(node, 0, sizeof(*node));
                node->coord = coord;
                node->coord_index = index;
                node->r = r;
        }
        return node;
}

static void kdnode_free(struct kdnode *node)
{
        free(node);
}

static int coord_cmp(double *c1, double *c2, int dim)
{
        int i;
        double ret;
        for (i = 0; i < dim; i++) {
                ret = *c1++ - *c2++;
                if (fabs(ret) >= DBL_EPSILON) {
                        return ret > 0 ? 1 : -1;
                }
        }

        if (fabs(ret) < DBL_EPSILON) {
                return 0;
        } else {
                return ret > 0 ? 1 : -1;
        }
}

static void knn_list_add(struct kdtree *tree, struct kdnode *node, double distance)
{
        if (node == NULL) return;

        struct knn_list *head = &tree->knn_list_head;
        struct knn_list *p = head->prev;
        if (tree->knn_num == 1) {
                if (p->distance > distance) {
                        p = p->prev;
                }
        } else {
                while (p != head && p->distance > distance) {
                        p = p->prev;
                }
        }

        if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
                struct knn_list *log = malloc(sizeof(*log));
                if (log != NULL) {
                        log->node = node;
                        log->distance = distance;
                        log->prev = p;
                        log->next = p->next;
                        p->next->prev = log;
                        p->next = log;
                        tree->knn_num++;
                }
        }
}

static void knn_list_adjust(struct kdtree *tree, struct kdnode *node, double distance)
{
        if (node == NULL) return;

        struct knn_list *head = &tree->knn_list_head;
        struct knn_list *p = head->prev;
        if (tree->knn_num == 1) {
                if (p->distance > distance) {
                        p = p->prev;
                }
        } else {
                while (p != head && p->distance > distance) {
                        p = p->prev;
                }
        }

        if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
                struct knn_list *log = head->prev;
                /* Replace the original max one */
                log->node = node;
                log->distance = distance;
                /* Remove from the max position */
                head->prev = log->prev;
                log->prev->next = head;
                /* insert as a new one */
                log->prev = p;
                log->next = p->next;
                p->next->prev = log;
                p->next = log;
        }
}

static void knn_list_clear(struct kdtree *tree)
{
        struct knn_list *head = &tree->knn_list_head;
        struct knn_list *p = head->next;
        while (p != head) {
                struct knn_list *prev = p;
                p = p->next;
                free(prev);
        }
        tree->knn_num = 0;
}

static void resize(struct kdtree *tree)
{
        tree->capacity *= 2;
        tree->coords = realloc(tree->coords, tree->dim * sizeof(double) * tree->capacity);
        tree->coord_table = realloc(tree->coord_table, sizeof(double *) * tree->capacity);
        tree->coord_indexes = realloc(tree->coord_indexes, sizeof(long) * tree->capacity);
        tree->coord_deleted = realloc(tree->coord_deleted, sizeof(char) * tree->capacity);
        tree->coord_passed = realloc(tree->coord_passed, sizeof(char) * tree->capacity);
        coord_table_reset(tree);
        coord_index_reset(tree);
        coord_deleted_reset(tree);
        coord_passed_reset(tree);
}

static void kdnode_dump(struct kdnode *node, int dim)
{
        int i;
        if (node->coord != NULL) {
                printf("(");
                for (i = 0; i < dim; i++) {
                        if (i != dim - 1) {
                                printf("%.2f,", node->coord[i]);
                        } else {
                                printf("%.2f)\n", node->coord[i]);
                        }
                }
        } else {
                printf("(none)\n");
        }
}

void kdtree_insert(struct kdtree *tree, double *coord)
{
        if (tree->count + 1 > tree->capacity) {
                resize(tree);
        }
        memcpy(tree->coord_table[tree->count++], coord, tree->dim * sizeof(double));
}

static void knn_pickup(struct kdtree *tree, struct kdnode *node, double *target, int k)
{
        double dist = distance(node->coord, target, tree->dim);
        if (tree->knn_num < k) {
                knn_list_add(tree, node, dist);
        } else {
                if (dist < knn_max(tree)) {
                        knn_list_adjust(tree, node, dist);
                } else if (fabs(dist - knn_max(tree)) < DBL_EPSILON) {
                        knn_list_add(tree, node, dist);
                }
        }
}

static void kdtree_search_recursive(struct kdtree *tree, struct kdnode *node, double *target, int k, int *pickup)
{
        if (node == NULL || kdnode_passed(tree, node)) {
                return;
        }

        int r = node->r;
        if (!knn_search_on(tree, k, node->coord[r], target[r])) {
                return;
        }

        if (*pickup) {
                tree->coord_passed[node->coord_index] = 1;
                knn_pickup(tree, node, target, k);
                kdtree_search_recursive(tree, node->left, target, k, pickup);
                kdtree_search_recursive(tree, node->right, target, k, pickup);
        } else {
                if (is_leaf(node)) {
                        *pickup = 1;
                } else {
                        if (target[r] <= node->coord[r]) {
                                kdtree_search_recursive(tree, node->left, target, k, pickup);
                                kdtree_search_recursive(tree, node->right, target, k, pickup);
                        } else {
                                kdtree_search_recursive(tree, node->right, target, k, pickup);
                                kdtree_search_recursive(tree, node->left, target, k, pickup);
                        }
                }
                /* back track and pick up  */
                if (*pickup) {
                        tree->coord_passed[node->coord_index] = 1;
                        knn_pickup(tree, node, target, k);
                }
        }
}

void kdtree_knn_search(struct kdtree *tree, double *target, int k)
{
        if (k > 0) {
                int pickup = 0;
                kdtree_search_recursive(tree, tree->root, target, k, &pickup);
        }
}

void kdtree_delete(struct kdtree *tree, double *coord)
{
        int r = 0;
        struct kdnode *node = tree->root;
        struct kdnode *parent = node;

        while (node != NULL) {
                if (node->coord == NULL) {
                        if (parent->right->coord == NULL) {
                                break;
                        } else {
                                node = parent->right;
                                continue;
                        }
                }

                if (coord[r] < node->coord[r]) {
                        parent = node;
                        node = node->left;
                } else if (coord[r] > node->coord[r]) {
                        parent = node;
                        node = node->right;
                } else {
                        int ret = coord_cmp(coord, node->coord, tree->dim);
                        if (ret < 0) {
                                parent = node;
                                node = node->left;
                        } else if (ret > 0) {
                                parent = node;
                                node = node->right;
                        } else {
                                node->coord = NULL;
                                break;
                        }
                }
                r = (r + 1) % tree->dim;
        }
}

static void kdnode_build(struct kdtree *tree, struct kdnode **nptr, int r, long low, long high)
{
        if (low == high) {
                long index = tree->coord_indexes[low];
                *nptr = kdnode_alloc(tree->coord_table[index], index, r);
        } else if (low < high) {
                /* Sort and fetch the median to build a balanced BST */
                quicksort(tree, low, high, r);
                long median = low + (high - low) / 2;
                long median_index = tree->coord_indexes[median];
                struct kdnode *node = *nptr = kdnode_alloc(tree->coord_table[median_index], median_index, r);
                r = (r + 1) % tree->dim;
                kdnode_build(tree, &node->left, r, low, median - 1);
                kdnode_build(tree, &node->right, r, median + 1, high);
        }
}

static void kdtree_build(struct kdtree *tree)
{
        kdnode_build(tree, &tree->root, 0, 0, tree->count - 1);
}

void kdtree_rebuild(struct kdtree *tree)
{
        long i, j;
        size_t size_of_coord = tree->dim * sizeof(double);
        for (i = 0, j = 0; j < tree->count; i++, j++) {
                while (j < tree->count && tree->coord_deleted[j]) {
                        j++;
                }
                if (i != j && j < tree->count) {
                        memcpy(tree->coord_table[i], tree->coord_table[j], size_of_coord);
                        tree->coord_deleted[i] = 0;
                }
        }
        tree->count = i;
        coord_index_reset(tree);
        kdtree_build(tree);
}

struct kdtree *kdtree_init(int dim)
{
        struct kdtree *tree = malloc(sizeof(*tree));
        if (tree != NULL) {
                tree->root = NULL;
                tree->dim = dim;
                tree->count = 0;
                tree->capacity = 65536;
                tree->knn_list_head.next = &tree->knn_list_head;
                tree->knn_list_head.prev = &tree->knn_list_head;
                tree->knn_list_head.node = NULL;
                tree->knn_list_head.distance = 0;
                tree->knn_num = 0;
                tree->coords = malloc(dim * sizeof(double) * tree->capacity);
                tree->coord_table = malloc(sizeof(double *) * tree->capacity);
                tree->coord_indexes = malloc(sizeof(long) * tree->capacity);
                tree->coord_deleted = malloc(sizeof(char) * tree->capacity);
                tree->coord_passed = malloc(sizeof(char) * tree->capacity);
                coord_index_reset(tree);
                coord_table_reset(tree);
                coord_deleted_reset(tree);
                coord_passed_reset(tree);
        }
        return tree;
}

static void kdnode_destroy(struct kdnode *node)
{
        if (node == NULL) return;
        kdnode_destroy(node->left);
        kdnode_destroy(node->right);
        kdnode_free(node);
}

void kdtree_destroy(struct kdtree *tree)
{
        kdnode_destroy(tree->root);
        knn_list_clear(tree);
        free(tree->coords);
        free(tree->coord_table);
        free(tree->coord_indexes);
        free(tree->coord_deleted);
        free(tree->coord_passed);
        free(tree);
}

#define _KDTREE_DEBUG

#ifdef _KDTREE_DEBUG
struct kdnode_backlog {
        struct kdnode *node;
        int next_sub_idx;
};

void kdtree_dump(struct kdtree *tree)
{
        int level = 0;
        struct kdnode *node = tree->root;
        struct kdnode_backlog nbl, *p_nbl = NULL;
        struct kdnode_backlog nbl_stack[KDTREE_MAX_LEVEL];
        struct kdnode_backlog *top = nbl_stack;

        for (; ;) {
                if (node != NULL) {
                        /* Fetch the pop-up backlogged node's sub-id.
                         * If not backlogged, fetch the first sub-id. */
                        int sub_idx = p_nbl != NULL ? p_nbl->next_sub_idx : KDTREE_RIGHT_INDEX;

                        /* Backlog should be left in next loop */
                        p_nbl = NULL;

                        /* Backlog the node */
                        if (is_leaf(node) || sub_idx == KDTREE_LEFT_INDEX) {
                                top->node = NULL;
                                top->next_sub_idx = KDTREE_RIGHT_INDEX;
                        } else {
                                top->node = node;
                                top->next_sub_idx = KDTREE_LEFT_INDEX;
                        }
                        top++;
                        level++;

                        /* Draw lines as long as sub_idx is the first one */
                        if (sub_idx == KDTREE_RIGHT_INDEX) {
                                int i;
                                for (i = 1; i < level; i++) {
                                        if (i == level - 1) {
                                                printf("%-8s", "+-------");
                                        } else {
                                                if (nbl_stack[i - 1].node != NULL) {
                                                        printf("%-8s", "|");
                                                } else {
                                                        printf("%-8s", " ");
                                                }
                                        }
                                }
                                kdnode_dump(node, tree->dim);
                        }

                        /* Move down according to sub_idx */
                        node = sub_idx == KDTREE_LEFT_INDEX ? node->left : node->right;
                } else {
                        p_nbl = top == nbl_stack ? NULL : --top;
                        if (p_nbl == NULL) {
                                /* End of traversal */
                                break;
                        }
                        node = p_nbl->node;
                        level--;
                }
        }
}
#endif

python

class kdtree(object):
    
    # 创建 kdtree
    # point_list 是一个 list 的 pair,pair[0] 是一 tuple 的特征,pair[1] 是类别
    def __init__(self, point_list, depth=0, root=None):
        
        if len(point_list)>0:
            
            # 轮换按照树深度选择坐标轴
            k = len(point_list[0][0])
            axis = depth % k
            
            # 选中位线,切
            point_list.sort(key=lambda x:x[0][axis])
            median = len(point_list) // 2
            
            self.axis = axis
            self.root = root
            self.size = len(point_list)
            
            # 造节点
            self.node = point_list[median]
            # 递归造左枝和右枝
            if len(point_list[:median])>0:
                self.left = kdtree(point_list[:median], depth+1, self)
            else:
                self.left = None
            if len(point_list[median+1:])>0:
                self.right = kdtree(point_list[median+1:], depth+1, self)
            else:
                self.right = None
            # 记录是按哪个方向切的还有树根

        else:
            return None
    
    # 在树上加一点
    def insert(self, point):
        self.size += 1
        
        # 分析是左还是右,递归加在叶子上
        if point[0][self.axis]<self.node[0][self.axis]:
            if self.left!=None:
                self.left.insert(point)
            else:
                self.left = kdtree([point], self.axis+1, self)
        else:
            if self.right!=None:
                self.right.insert(point)
            else:
                self.right = kdtree([point], self.axis+1, self)
            
            
    # 输入一点
    # 按切分寻找叶子
    def find_leaf(self, point):
        if self.left==None and self.right==None:
            return self
        elif self.left==None:
            return self.right.find_leaf(point)
        elif self.right==None:
            return self.left.find_leaf(point)
        elif point[self.axis]<self.node[0][self.axis]:
            return self.left.find_leaf(point)
        else:
            return self.right.find_leaf(point)
        

    # 查找最近的 k 个点,复杂度 O(DlogN),D是维度,N是树的大小
    # 输入一点、一距离函数、一k。距离函数默认是 L_2
    def knearest(self, point, k=1, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):
        # 往下戳到最底叶
        leaf = self.find_leaf(point)
        # 从叶子网上爬
        return leaf.k_down_up(point, k, dist, result=[], stop=self, visited=None)


    # 从下往上爬函数,stop是到哪里去,visited是从哪里来
    def k_down_up(self, point,k, dist, result=[],stop=None, visited=None):

        # 选最长距离
        if result==[]:
            max_dist = 0
        else:
            max_dist = max([x[1] for x in result])

        other_result=[]

        # 如果离分界线的距离小于现有最大距离,或者数据点不够,就从另一边的树根开始刨
        if (self.left==visited and self.node[0][self.axis]-point[self.axis]<max_dist and self.right!=None)\
            or (len(result)<k and self.left==visited and self.right!=None):
            other_result=self.right.knearest(point,k, dist)

        if (self.right==visited and point[self.axis]-self.node[0][self.axis]<max_dist and self.left!=None)\
            or (len(result)<k and self.right==visited and self.left!=None):
            other_result=self.left.knearest(point, k, dist)

        # 刨出来的点放一起,选前 k 个
        result.append((self.node, dist(point, self.node[0])))
        result = sorted(result+other_result, key=lambda pair: pair[1])[:k]

        # 到停点就返回结果
        if self==stop:
            return result
        # 没有就带着现有结果接着往上爬
        else:
            return self.root.k_down_up(point,k,  dist, result, stop, self)

    # 输入 特征、类别、k、距离函数
    # 返回这个点属于该类别的概率
    def kNN_prob(self, point, label, k, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):
        nearests = self.knearest(point,  k, dist)
        return float(len([pair for pair in nearests if pair[0][1]==label]))/float(len(nearests))


    # 输入 特征、k、距离函数
    # 返回该点概率最大的类别以及相对应的概率
    def kNN(self, point, k, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):
        nearests = self.knearest(point, k , dist)

        statistics = {}
        for data in nearests:
            label = data[0][1]
            if label not in statistics: 
                statistics[label] = 1
            else:
                statistics[label] += 1

        max_label = max(statistics.iteritems(), key=operator.itemgetter(1))[0]
        return max_label, float(statistics[max_label])/float(len(nearests))

参考自:JoinQuant量化课堂

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浩波的笔记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值