题意: 给你n个k维的点, 然后q个询问, 对于每个询问会给出一个点,需要我们求出离这个点最近的m个点,按距离从小到大排序。
思路: :和二维类似, 但求的是最近的m个点, 就要结合优先队列了,在遇到一个点时,如果优先队列还没有满,就先加入,如果满了, 就和队列头部的元素相比较,看谁最优,在判断是否继续分支的时候也要根据队列的情况,来进一步的判断是否继续向下搜索。
补充知识:pait内部自定义了排序优先级, 第一排序的时first, 然后是second
具体细节见于代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1e5 + 50;
int split[20], cur, dim;
struct Point // 点的坐标
{
ll x[5]; // 多维
bool operator <(const Point &b) const
{
return x[cur] < b.x[cur];
}
}p[N], ori[N];
int n, k, m1, q;
priority_queue<pair<ll, Point> > que;
// cur 当前划分的是是第几维 dim 总共是第几维
bool cmp(const Point &a, const Point &b)
{
return a.x[cur] < b.x[cur];
}
#define lson l, m - 1, depth + 1
#define rson m + 1, r, depth + 1
// 模板函数
template <class T> T sqr(T x) {return x * x;}
const ll inf = 0x7777777777777777ll;
ll dist(const Point &x, const Point &y) // 返回两个点的距离
{
ll ret = 0;
for(int i = 0; i < dim; i ++)
{
ret += sqr(x.x[i] - y.x[i]);
}
return ret;
}
void build(const int &l, const int &r, const int &depth) // 建树
{
if(l >= r) return;
int m = l + r >> 1;
cur = depth % dim;
nth_element(p + l, p + m, p + r + 1);
build(lson);
build(rson);
}
// 查找离x最近的点
ll Find(const Point &x, const int &l, const &r, const int &depth)
{
int cur = depth % dim;//第几维
if(l > r)
{
return inf;
}
int m = l + r >> 1;
ll ret = dist(x, p[m]), tmp; // ret记录与x的距离
while(que.size() >= m1 && que.top().first > ret) que.pop();
if(que.size() < m1)
{
que.push({ret, p[m]});
}
if(x.x[cur] < p[m].x[cur])
{
/*如果在当前维度下, x的cur维坐标小于p的那么就查询它的左边
*/
tmp = Find(x, lson);
if(que.size() < m1 || que.top().first > sqr(x.x[cur] - p[m].x[cur]))
// 如果当前队列元素还没满或者队列中的元素可能不是最优
tmp = min(tmp, Find(x, rson));
}
else
{
tmp = Find(x, rson);//同理
if(que.size() < m1 || que.top().first > sqr(x.x[cur] - p[m].x[cur]))
{
tmp = min(tmp, Find(x, lson));
}
}
return min(ret, tmp);
}
int main()
{
while(~scanf("%d%d", &n, &k))
{
dim = k;
for(int i = 0; i < n; i++)
{
for(int j = 0; j < k; j ++)
{
scanf("%lld", &ori[i].x[j]);
}
p[i] = ori[i]; // 保存点集
}
build(0, n - 1, 0);//建立KD_tree,进行划分
scanf("%d", &q);
while(q --)
{
Point now;
for(int i = 0; i < k; i++)
{
scanf("%lld", &now.x[i]);
}
scanf("%d", &m1);
Find(now, 0, n - 1, 0);
int t = 0;
Point pp[21];
while(!que.empty())
{
pp[++t] = que.top().second;
que.pop();
}
printf("the closest %d points are:\n", t);
for(int i = m1; i > 0; i --)
{
printf("%lld", pp[i].x[0]);
for(int j = 1; j < k; j ++)
{
printf(" %lld", pp[i].x[j]);
}
puts("");
}
}
}
return 0;
}
’