一道KNN(K近邻)的题。直接用kd树加上一个暴力插入更新前k近邻就撸过去了。写的时候犯了一个错误:搜索完一边子树返回后,当前层的划分维度(全局变量 cdm)已经被递归调用改掉了,而我直接拿它来判断是否需要搜索另一边子树,结果WA了半天。
代码如下:
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

// Capacity limits: K = max dimensionality, N = max number of points.
const int K = 6;
const int N = 55555;

// dm  : dimensionality of the current test case (read from input).
// cdm : axis used by Node::operator< while build() partitions with
//       nth_element; it is global because the comparator needs it.
int dm, cdm;

template<class T> T sqr(T x) { return x * x; }

struct Node {
    int x[K];    // coordinates; only the first dm entries are meaningful
    Node *c[2];  // children: c[0] = smaller side, c[1] = greater-or-equal side

    // Lexicographic comparison starting at axis cdm and wrapping around.
    // Identical points must compare false: nth_element requires a strict
    // weak ordering (returning true here was undefined behaviour).
    bool operator<(const Node &a) const {
        for (int i = 0; i < dm; i++) {
            int axis = (cdm + i) % dm;
            if (x[axis] != a.x[axis]) return x[axis] < a.x[axis];
        }
        return false;
    }
} node[N];

// Squared Euclidean distance over the first dm axes, in long long so
// that large coordinates cannot overflow.
long long dist(const Node *a, const Node *b) {
    long long ret = 0;
    for (int i = 0; i < dm; i++) ret += sqr(static_cast<long long>(a->x[i]) - b->x[i]);
    return ret;
}

struct KDT {
    vector<Node *> knn;     // current nearest points, ascending by distance
    vector<long long> dis;  // dis[i] = squared distance of knn[i]
    int top, sz;            // top = index of last valid result (-1 if none); sz = tree size
    Node *RT;               // tree root

    // Builds the subtree over node[l..r] at depth dp; split axis is dp % dm.
    void build(int l, int r, int dp, Node *&rt) {
        if (l > r) { rt = nullptr; return; }
        cdm = dp % dm;  // consumed by Node::operator< inside nth_element
        int m = (l + r) / 2;
        nth_element(node + l, node + m, node + r + 1);
        rt = node + m;
        build(l, m - 1, dp + 1, rt->c[0]);
        build(m + 1, r, dp + 1, rt->c[1]);
    }
    void build(int l, int r) {
        sz = r - l + 1;
        build(l, r, 0, RT);
    }

    // Recursive k-nearest search from rt at depth dp.
    // The split axis is kept in a local variable: the recursive call below
    // would otherwise clobber a shared global before the pruning test.
    void search(int dp, Node *x, Node *rt, int k) {
        if (!rt) return;
        int axis = dp % dm;
        long long d = dist(x, rt);
        // Insertion-sort step: shift farther results right, then insert.
        // When the list is already full, the slot at top + 1 serves as scratch.
        int p = top;
        while (p >= 0 && d < dis[p]) {
            dis[p + 1] = dis[p];
            knn[p + 1] = knn[p];
            p--;
        }
        p++;
        dis[p] = d;
        knn[p] = rt;
        if (top + 1 < k) top++;
        bool r = x->x[axis] >= rt->x[axis];
        search(dp + 1, x, rt->c[r], k);
        // Visit the far side only if the list is not full, or if the
        // splitting plane is closer than the current k-th best.
        long long gap = static_cast<long long>(rt->x[axis]) - x->x[axis];
        if (top + 1 < k || sqr(gap) < dis[top]) search(dp + 1, x, rt->c[!r], k);
    }

    // Finds the k nearest points to x. Results are knn[0..top], ascending
    // by squared distance; k <= 0 or an empty tree yields top == -1.
    void search(Node *x, int k) {
        top = -1;
        int cap = min(k, sz);  // at most sz results can ever be produced
        if (cap <= 0) return;
        knn.assign(cap + 1, nullptr);  // +1: scratch slot used during insertion
        dis.assign(cap + 1, 0);
        search(0, x, RT, k);
    }

    // Debug helper: in-order dump of the first two coordinates.
    void pre(Node *x) {
        if (!x) return;
        pre(x->c[0]);
        printf("%d %d\n", x->x[0], x->x[1]);
        pre(x->c[1]);
    }
} kdt;

int main() {
    int n, m, k;
    Node tmp{};
    // Each test case: n points in dm dimensions, then m queries of (point, k).
    while (scanf("%d%d", &n, &dm) == 2) {
        // Reject sizes that would overflow node[] or Node::x.
        if (n < 0 || n > N || dm < 1 || dm > K) break;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < dm; j++) scanf("%d", node[i].x + j);
            node[i].c[0] = node[i].c[1] = nullptr;
        }
        kdt.build(0, n - 1);
        if (scanf("%d", &m) != 1) break;
        while (m--) {
            for (int i = 0; i < dm; i++) scanf("%d", tmp.x + i);
            scanf("%d", &k);
            kdt.search(&tmp, k);
            printf("the closest %d points are:\n", k);
            for (int i = 0; i <= kdt.top; i++) {
                for (int j = 0; j < dm; j++) {
                    if (j) putchar(' ');
                    printf("%d", kdt.knn[i]->x[j]);
                }
                puts("");
            }
        }
    }
    return 0;
}
——written by Lyon