Kd树寻找距离P最近的K个点(C++ 代码实现)

推荐学习该知识文章

实习时划水手动实现代码:

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>
using namespace std;
struct KdNode {
  KdNode* Parent;
  KdNode* LSon;
  KdNode* RSon;
  bool div;
  float Axis, Axis_x, Axis_y;
  KdNode() { Parent = nullptr, LSon = nullptr, RSon = nullptr; }
  KdNode(KdNode* Par, KdNode* Ls, KdNode* Rs, float axis, float axis_x,
         float axis_y)
      : Parent(Par),
        LSon(Ls),
        RSon(Rs),
        Axis(axis),
        Axis_x(axis_x),
        Axis_y(axis_y) {}
};
unordered_map<KdNode*, bool> vis;
priority_queue<pair<float, KdNode*>, vector<pair<float, KdNode*> > > q;
struct Point {
  float Posx, Posy;  //坐标轴x
} KdAxis[1005];

bool cmpx(Point Pxo, Point Pxt) { return Pxo.Posx < Pxt.Posx; }

bool cmpy(Point Pxo, Point Pxt) { return Pxo.Posy < Pxt.Posy; }

inline float GetDis(float Kdx, float Kdy, float Px, float Py) {
  return (Kdx - Px) * (Kdx - Px) + (Kdy - Py) * (Kdy - Py);
}

void BuildKdTree(KdNode* Rt, KdNode* Fa, int L, int R, int div) {
  // cout << L << '*' << R << endl;
  Rt->Parent = Fa;
  Rt->div = div;
  if (!div) {
    sort(KdAxis + L, KdAxis + R + 1, cmpx);
  } else {
    sort(KdAxis + L, KdAxis + R + 1, cmpy);
  }

  div = (div + 1) % 2;
  int MID = (L + R) / 2;
  KdNode* Ls = new KdNode();
  KdNode* Rs = new KdNode();

  Rt->Axis = div == 0 ? KdAxis[MID].Posx : KdAxis[MID].Posy;
  Rt->Axis_x = KdAxis[MID].Posx;
  Rt->Axis_y = KdAxis[MID].Posy;
  if (L <= MID - 1) {
    Rt->LSon = Ls;
  }
  if (MID + 1 <= R) {
    Rt->RSon = Rs;
  }
  if (L <= MID - 1) {
    BuildKdTree(Ls, Rt, L, MID - 1, div);
  }
  if (MID + 1 <= R) {
    BuildKdTree(Rs, Rt, MID + 1, R, div);
  }
}

void dfs(KdNode* Rt, float px, float py, int k) {
  if (Rt == nullptr) {
    return;
  }
  if (!vis[Rt]) {
    if (Rt->div == 0) {
      if (px <= Rt->Axis_x) {
        dfs(Rt->LSon, px, py, k);

        if (!vis[Rt]) {
          if (q.size() < k) {
            q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
          } else {
            if (GetDis(Rt->Axis_x, Rt->Axis_y, px, py) < q.top().first) {
              q.pop();
              q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
            }
          }
        }

        vis[Rt] = true;
        if (q.top().first > GetDis(Rt->Axis, 0, px, py) || q.size() < k) {
          dfs(Rt->RSon, px, py, k);
        }

      } else {
        dfs(Rt->RSon, px, py, k);

        if (!vis[Rt]) {
          if (q.size() < k) {
            q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
          } else {
            if (GetDis(Rt->Axis_x, Rt->Axis_y, px, py) < q.top().first) {
              q.pop();
              q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
            }
          }
        }

        vis[Rt] = true;
        if (q.top().first > GetDis(Rt->Axis, 0, px, py) || q.size() < k) {
          dfs(Rt->LSon, px, py, k);
        }
      }
    } else {
      if (py <= Rt->Axis_y) {
        dfs(Rt->LSon, px, py, k);

        if (!vis[Rt]) {
          if (q.size() < k) {
            q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
          } else {
            if (GetDis(Rt->Axis_x, Rt->Axis_y, px, py) < q.top().first) {
              q.pop();
              q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
            }
          }
        }

        vis[Rt] = true;
        if (q.top().first > GetDis(0, Rt->Axis, px, py) || q.size() < k) {
          dfs(Rt->RSon, px, py, k);
        }
      } else {
        dfs(Rt->RSon, px, py, k);

        if (!vis[Rt]) {
          if (q.size() < k) {
            q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
          } else {
            if (GetDis(Rt->Axis_x, Rt->Axis_y, px, py) < q.top().first) {
              q.pop();
              q.push(make_pair(GetDis(Rt->Axis_x, Rt->Axis_y, px, py), Rt));
            }
          }
        }

        vis[Rt] = true;
        if (q.top().first > GetDis(0, Rt->Axis, px, py) || q.size() < k) {
          dfs(Rt->LSon, px, py, k);
        }
      }
    }
  }
}

int main() {
  KdAxis[1].Posx = 6.27, KdAxis[1].Posy = 5.50;
  KdAxis[2].Posx = 1.24, KdAxis[2].Posy = -2.86;
  KdAxis[3].Posx = -6.88, KdAxis[3].Posy = -5.40;
  KdAxis[4].Posx = -2.96, KdAxis[4].Posy = -2.50;
  KdAxis[5].Posx = -4.60, KdAxis[5].Posy = -10.55;
  KdAxis[6].Posx = -4.96, KdAxis[6].Posy = 12.61;
  KdAxis[7].Posx = 1.75, KdAxis[7].Posy = 12.26;
  KdAxis[8].Posx = 17.05, KdAxis[8].Posy = -12.79;
  KdAxis[9].Posx = 7.75, KdAxis[9].Posy = -22.68;
  KdAxis[10].Posx = 15.31, KdAxis[10].Posy = -13.16;
  KdAxis[11].Posx = 10.80, KdAxis[11].Posy = -5.03;
  KdAxis[12].Posx = 7.83, KdAxis[12].Posy = 15.70;
  KdAxis[13].Posx = 14.63, KdAxis[13].Posy = -0.35;

  KdNode* Rt = new KdNode();
  int div = 0, L = 1, R = 13;
  BuildKdTree(Rt, nullptr, L, R, div);

  dfs(Rt, -1, -5, 5);
  while (!q.empty()) {
    KdNode* NowNode = q.top().second;
    cout << NowNode->Axis_x << ' ' << NowNode->Axis_y << endl;
    q.pop();
  }
  return 0;
}

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
要使用kd树找到距离p点最近的n个点,可以使用以下步骤: 1. 构建kd树,将所有点插入到kd树中。 2. 从根节点开始,递归遍历kd树。在每个节点,根据当前节点的分裂维度和分裂值,将查询点p与节点的距离计算出来。 3. 如果当前节点的距离比当前找到的n个最近点中最远的点还要远,那么可以直接返回,不再继续递归。 4. 否则,将当前节点加入到候选最近点集合中,如果候选最近点集合的大小大于n,则需要删除最远的点。 5. 继续递归遍历当前节点的子树,直到遍历完所有的子树。 6. 最后,候选最近点集合中的所有点就是距离p点最近的n个点。 下面是一个使用kd树查找距离p点最近的n个点的Python实现: ```python import heapq import numpy as np class KDTree: def __init__(self, data): self.data = data self.n = data.shape[0] self.k = data.shape[1] self.tree = self.build_tree(0, self.n, 0) def build_tree(self, left, right, depth): if left >= right: return None mid = (left + right) // 2 axis = depth % self.k sorted_idx = np.argsort(self.data[:, axis]) left_idx = np.where(sorted_idx[:mid] == sorted_idx[mid])[0] right_idx = np.where(sorted_idx[mid + 1:] == sorted_idx[mid])[0] + mid + 1 node = { "idx": sorted_idx[mid], "axis": axis, "left": self.build_tree(left + len(left_idx), right - len(right_idx), depth + 1), "right": self.build_tree(left, left + len(left_idx), depth + 1) } return node def search_knn(self, p, n): candidate = [] heapq.heapify(candidate) self.search(self.tree, p, n, candidate) return [self.data[idx] for _, idx in heapq.nsmallest(n, candidate)] def search(self, node, p, n, candidate): if node is None: return dist = np.linalg.norm(p - self.data[node["idx"]]) if len(candidate) < n or dist < -candidate[0][0]: heapq.heappush(candidate, (-dist, node["idx"])) if len(candidate) > n: heapq.heappop(candidate) axis = node["axis"] if p[axis] < self.data[node["idx"], axis]: self.search(node["left"], p, n, candidate) else: self.search(node["right"], p, n, candidate) if len(candidate) < n or abs(p[axis] - self.data[node["idx"], axis]) < -candidate[0][0]: if p[axis] < self.data[node["idx"], axis]: self.search(node["right"], p, n, candidate) else: self.search(node["left"], p, n, candidate) ``` 上面的代码中,`KDTree`类用于构建kd树查找最近的n个点。`build_tree`方法用于递归构建kd树。`search_knn`方法用于查找距离p点最近的n个点,其中使用了一个小根堆来维护当前找到的最近的n个点。`search`方法用于递归查找最近的n个点
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值