一维KD树(貌似用不上,但是理解一维有助于理解二维)
Code
#include <bits/stdc++.h>
#define MAX 500005
#define NIL -1
using namespace std;
#define for(x, y) for(int i=x; i<y; i++)
// One node of the 1D tree, stored in the global pool T.
struct Node1D {
    int loc; // index into the sorted value array P held by this node
    int l;   // pool index of the left child, or NIL when absent
    int r;   // pool index of the right child, or NIL when absent
};
// Node pool for the 1D tree; make1D takes slots from it in increasing order.
Node1D T[MAX];
// Number of values read into P.
int n;
// Index of the next unused slot in T (advanced by make1D).
int np;
// Input values; main sorts them ascending before the tree is built.
int P[MAX];
int make1D(int l, int r)
{
if (l >= r)
return NIL;
int mid = (l + r) / 2;
int t = np++;
T[t].loc = mid;
T[t].l = make1D(l, mid);
T[t].r = make1D(mid + 1, r);
return t;
}
void find1D(int p, int sx, int tx)
{
int x = P[T[p].loc];
if (x >= sx && x <= tx)
cout << x << " ";
if (T[p].l != NIL && x >= sx)
find1D(T[p].l, sx, tx);
if (T[p].r != NIL && x <= tx)
find1D(T[p].r, sx, tx);
}
int main()
{
cin >> n;
for (0, n)
cin >> P[i];
sort(P, P + n);
make1D(0, n);
int q;
cin >> q;
while (q--) {
int x, y;
cin >> x >> y;
find1D(0, x, y);
cout << endl;
}
return 0;
}
Data(样例输入:依次为 n、n 个数、查询次数 q,以及 q 行查询区间 x y)
10
1 2 2 2 3 4 4 5 6 7
5
2 3
2 4
1 5
2 6
4 2
二维KD树(套用模板代码通过的,目前理解起来还很困难)
#include <bits/stdc++.h>
#define MAX 500005
#define NIL -1
using namespace std;
#define sd(n) scanf("%d", &n)
#define pd(n) printf("%d\n", n)
#define ps(s) printf(s)
#define for(x, y) for(int i=x; i<y; i++)
// A point record: first = id (the point's index in input order),
// second.first = x coordinate, second.second = y coordinate.
typedef pair<int, pair<int, int> > Pair;
// One node of the 2D tree, stored in the global pool T.
struct Node2D {
    int loc; // index into the point array P held by this node
    int l;   // pool index of the left child, or NIL when absent
    int r;   // pool index of the right child, or NIL when absent
};
// Node pool for the 2D tree; make2D takes slots from it in increasing order.
Node2D T[MAX];
// Number of points read into P.
int n;
// Index of the next unused slot in T (reset by init, advanced by make2D).
int np;
// The points; make2D reorders them in place while building the tree.
Pair P[MAX];
// Points matched by the current query; filled by find2D.
vector<Pair> ans;
// Orders points by ascending x coordinate; used for levels that split on x.
bool cmp1(const Pair &p1, const Pair &p2)
{
    const int lhsX = p1.second.first;
    const int rhsX = p2.second.first;
    return lhsX < rhsX;
}
// Orders points by ascending y coordinate; used for levels that split on y.
bool cmp2(const Pair &p1, const Pair &p2)
{
    const int lhsY = p1.second.second;
    const int rhsY = p2.second.second;
    return lhsY < rhsY;
}
// Orders points by ascending id; main uses it to sort each query's matches.
bool cmp3(const Pair &p1, const Pair &p2)
{
    const int lhsId = p1.first;
    const int rhsId = p2.first;
    return lhsId < rhsId;
}
// Resets the node allocator so the next make2D call fills T from slot 0.
void init()
{
np = 0;
}
void print()
{
for (0, ans.size())
pd(ans[i].first);
}
int make2D(int l, int r, int d)
{
if (l >= r)
return NIL;
int mid = (l + r) / 2;
int t = np++;
if (d % 2 == 0) //x轴为基准
sort(P + l, P + r, cmp1);
else //y轴为基准
sort(P + l, P + r, cmp2);
T[t].loc = mid;
T[t].l = make2D(l, mid, d + 1);
T[t].r = make2D(mid + 1, r, d + 1);
return t;
}
void find2D(int p, int sx, int tx, int sy, int ty, int d)
{
int x = P[T[p].loc].second.first;
int y = P[T[p].loc].second.second;
if (sx <= x && x <= tx && sy <= y && y <= ty)
ans.push_back(P[T[p].loc]);
if (d % 2 == 0) //x轴为基准
{
if (T[p].l != NIL && sx <= x)
find2D(T[p].l, sx, tx, sy, ty, d + 1);
if (T[p].r != NIL && x <= tx)
find2D(T[p].r, sx, tx, sy, ty, d + 1);
}
else //y轴为基准
{
if (T[p].l != NIL && sy <= y)
find2D(T[p].l, sx, tx, sy, ty, d + 1);
if (T[p].r != NIL && y <= ty)
find2D(T[p].r, sx, tx, sy, ty, d + 1);
}
}
int main()
{
sd(n);
init();
for (0, n) {
sd(P[i].second.first);
sd(P[i].second.second);
P[i].first = i;
}
int root = make2D(0, n, 0);
int q;
sd(q);
while (q--) {
int sx, tx, sy, ty;
sd(sx); sd(tx); sd(sy); sd(ty);
ans.clear();
find2D(root, sx, tx, sy, ty, 0);
sort(ans.begin(), ans.end(), cmp3);
print();
ps("\n");
}
return 0;
}