暂时wrong answer,求各路大神指错。。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
struct POINT
{
int p[5];
int &operator[](int x)
{
return p[x];
}
const int &operator[](int x) const
{
return p[x];
}
}point[500005];
bool cmp0(const POINT &a, const POINT &b)
{
return a[0] < b[0];
}
bool cmp1(const POINT &a, const POINT &b)
{
return a[1] < b[1];
}
bool cmp2(const POINT &a, const POINT &b)
{
return a[2] < b[2];
}
bool cmp3(const POINT &a, const POINT &b)
{
return a[3] < b[3];
}
bool cmp4(const POINT &a, const POINT &b)
{
return a[4] < b[4];
}
bool cmp5(const POINT &a, const POINT &b)
{
return a[5] < b[5];
}
bool cmp6(const POINT &a, const POINT &b)
{
return a[6] < b[6];
}
bool cmp7(const POINT &a, const POINT &b)
{
return a[7] < b[7];
}
bool cmp8(const POINT &a, const POINT &b)
{
return a[8] < b[8];
}
bool cmp9(const POINT &a, const POINT &b)
{
return a[9] < b[9];
}
int n, k;
int q, m;
bool (*cmp[10])(const POINT &a, const POINT &b);
int tree[1000010];
int splitAxis[500005];
POINT knn[20];
POINT p;
int knnSize;
long long distance2(const POINT &a, const POINT &b)
{
long long total = 0;
for (int i = 0; i < k; i++)
total += (b[i] - a[i]) * (b[i] - a[i]);
return total;
}
bool knnCmp(const POINT &a, const POINT &b)
{
return distance2(a, p) < distance2(b, p);
}
void addPoint(const POINT &crtPoint)
{
if (knnSize < m)
{
knn[knnSize++] = crtPoint;
sort(knn, knn + knnSize, knnCmp);
return;
}
if (distance2(crtPoint, p) < distance2(knn[knnSize - 1], p))
knn[knnSize - 1] = crtPoint;
sort(knn, knn + knnSize, knnCmp);
}
void BuildKDTree(int l, int r, int axis, int root)
{
if (r < l) return;
sort(point + l, point + r, cmp[axis]);
int mid = (l + r) / 2;
// while (mid + 1 <= r && point[mid][axis] == point[mid + 1][axis])
// mid++;
tree[root] = mid;
splitAxis[mid] = axis;
axis = (axis + 1) % k;
BuildKDTree(l, mid - 1, axis, root * 2 + 1);
BuildKDTree(mid + 1, r, axis, root * 2 + 2);
}
void TestKDTree()
{
for (int i = 0; i < n; i++)
{
int index = tree[i];
for (int j = 0; j < k; j++)
printf("%d ", point[index][j]);
printf("\n");
}
}
void findKnn(int root)
{
if (tree[root] == -1)
return;
int index = tree[root];
addPoint(point[index]);
int axis = splitAxis[index];
if (p[axis] <= point[index][axis])
{
findKnn(root * 2 + 1);
if (knnSize < m)
findKnn(root * 2 + 2);
else
{
long long disToSplitPlane = (point[index][axis] - p[axis]) * (point[index][axis] - p[axis]);
long long radius = distance2(knn[knnSize - 1], p);
if (radius >= disToSplitPlane)
findKnn(root * 2 + 2);
}
}
else
{
findKnn(root * 2 + 2);
if (knnSize < m)
findKnn(root * 2 + 1);
else
{
long long disToSplitPlane = (point[index][axis] - p[axis]) * (point[index][axis] - p[axis]);
long long radius = distance2(knn[knnSize - 1], p);
if (radius >= disToSplitPlane)
findKnn(root * 2 + 1);
}
}
}
int main()
{
cmp[0] = cmp0;
cmp[1] = cmp1;
cmp[2] = cmp2;
cmp[3] = cmp3;
cmp[4] = cmp4;
cmp[5] = cmp5;
cmp[6] = cmp6;
cmp[7] = cmp7;
cmp[8] = cmp8;
cmp[9] = cmp9;
while (scanf("%d", &n) != EOF)
{
scanf("%d", &k);
memset(point, 0, sizeof(point));
memset(tree, -1, sizeof(tree));
for (int i = 0; i < n; i++)
{
for (int j = 0; j < k; j++)
scanf("%d", &point[i][j]);
}
BuildKDTree(0, n - 1, 0, 0);
//TestKDTree();
scanf("%d", &q);
for (int i = 0; i < q; i++)
{
for (int j = 0; j < k; j++)
scanf("%d", &p[j]);
scanf("%d", &m);
knnSize = 0;
findKnn(0);
printf("the closest %d points are:\n", m);
for (int j = 0; j < m; j++)
{
for (int l = 0; l < k - 1; l++)
printf("%d ", knn[j][l]);
printf("%d\n", knn[j][k - 1]);
}
}
}
}