CDQ 分治的思想最早由 IOI2008 金牌得主陈丹琦在高中时整理并总结,目前这个思想的拓展十分广泛。
- 优点:可以将数据结构或者 DP 优化掉一维
- 缺点:这是离线算法。
引入
让我们来看一个问题
有 nn 个元素,第 ii 个元素有 ai,bi,ciai,bi,ci 三个属性,设 f(i)f(i) 表示满足 aj≤aiaj≤ai 且 bj≤bibj≤bi 且 cj≤cicj≤ci 且 j≠ij≠i 的 jj 的数量。
对于 d∈[0,n)d∈[0,n),求 f(i)=df(i)=d 的数量。
1≤n≤1051≤n≤105,1≤ai,bi,ci≤k≤2×1051≤ai,bi,ci≤k≤2×105。
这是一个三维偏序问题。
偏序问题:给定序列 AA,其中有序对 (Ai,Aj)(Ai,Aj),满足 i<ji<j 且 Ai<AjAi<Aj 这样的有序对我们称之为逆序对, 信息学竞赛中的逆序对问题,一般是要我们计数给出序列的逆序对个数的总和。其实可以把它看成一个特殊的二维偏序问题,或者说是离散化 xx 坐标的二维偏序问题。
而 CDQ 分治,可以来解决三维偏序问题。
上面的引入问题就是模板题 P3810 【模板】三维偏序(陌上花开) 的题意。
P3810 【模板】三维偏序(陌上花开)
变量及其含义
struct node { int x, y, z, cnt, ans; } s1[N], s2[N];
x, y, z
: 三个元素。cnt
:相同元素的个数。ans
:统计答案。
对于第一维 aa,我们可以先从小到大 sort
一遍,ii 号点前面的点的 aa 都比 aiai 小,这样我们就减少了一维的处理,还剩下两维。
bool cmp1(node a, node b) { if (a.x == b.x) { if (a.y == b.y) { return a.z < b.z; } else return a.y < b.y; } return a.x < b.x; } // main() 函数里面 n = read<int>(), k = read<int>(); mx = k; for (int i = 1, x, y, z; i <= n; ++ i) { x = read<int>(), y = read<int>(), z = read<int>(); s1[i].x = x, s1[i].y = y, s1[i].z = z; } sort(s1 + 1, s1 + n + 1, cmp1);
排完序后,我们可以将相同的元素合并为一个元素,结构体里的 cnt
就派上用场了。
int top = 0; for (int i = 1; i <= n; ++ i) { ++ top; if (s1[i].x != s1[i + 1].x || s1[i].y != s1[i + 1].y || s1[i].z != s1[i + 1].z) { s2[++ m].x = s1[i].x; s2[m].y = s1[i].y; s2[m].z = s1[i].z; s2[m].cnt = top; top = 0; } }
然后处理第二维,对于第二维,我们要求 bj≤bibj≤bi,按照前面的思路,我们肯定也要想方设法给第二维排序。
我们可以用 归并排序 的思想,先分别给左半个区间和右半个区间按照第二维从小到大排序,然后依次处理,由于是在 aa 排好序的基础上进行的在排序,且这两个的区间还没有合并,所以无论怎么打乱,都可以保证左半边元素的 aa 小于等于右半边元素的 aa。
对于第三维,相当于到了我们找逆序对的环节了,我们有归并排序和树状数组两种方法,但由于归并排序已经放到前面去处理第二维了,所以我们用树状数组来处理第三维,将节点依次插入树状数组,统计。
bool cmp2(node a, node b) { if (a.y == b.y) { return a.z < b.z; } return a.y < b.y; } void add(int u, int w) { for (int i = u; i <= mx; i += lowbit(i)) { t[i] += w; } } int ask(int u) { int sum = 0; for (int i = u; i; i -= lowbit(i)) { sum += t[i]; } return sum; } void cdq(int l, int r) { if (l == r) return ; int mid = (l + r) >> 1; cdq(l, mid); cdq(mid + 1, r); sort(s2 + l, s2 + mid + 1, cmp2); sort(s2 + mid + 1, s2 + r + 1, cmp2); int i, j = l; for (i = mid + 1; i <= r; ++ i) { while (s2[i].y >= s2[j].y && j <= mid) { // 一旦不符合,先统计,然后右指针右移一位。 add(s2[j].z, s2[j].cnt); // 插入 ++ j; } s2[i].ans += ask(s2[i].z); } for (i = l; i < j; ++ i) { // 清空数组,memset 常数太大。 add(s2[i].z, -s2[i].cnt); } }
最后就是处理答案了,完整代码:
/* The code was written by yifan, and yifan is neutral!!! */ #include <bits/stdc++.h> using namespace std; typedef long long ll; #define lowbit(i) (i & (-i)) template<typename T> inline T read() { T x = 0; bool fg = 0; char ch = getchar(); while (ch < '0' || ch > '9') { fg |= (ch == '-'); ch = getchar(); } while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); } return fg ? ~x + 1 : x; } const int N = 1e5 + 5; int n, k, mx, m; int t[N << 1], res[N]; struct node { int x, y, z, cnt, ans; } s1[N], s2[N]; bool cmp1(node a, node b) { if (a.x == b.x) { if (a.y == b.y) { return a.z < b.z; } else return a.y < b.y; } return a.x < b.x; } bool cmp2(node a, node b) { if (a.y == b.y) { return a.z < b.z; } return a.y < b.y; } void add(int u, int w) { for (int i = u; i <= mx; i += lowbit(i)) { t[i] += w; } } int ask(int u) { int sum = 0; for (int i = u; i; i -= lowbit(i)) { sum += t[i]; } return sum; } void cdq(int l, int r) { if (l == r) return ; int mid = (l + r) >> 1; cdq(l, mid); cdq(mid + 1, r); sort(s2 + l, s2 + mid + 1, cmp2); sort(s2 + mid + 1, s2 + r + 1, cmp2); int i, j = l; for (i = mid + 1; i <= r; ++ i) { while (s2[i].y >= s2[j].y && j <= mid) { add(s2[j].z, s2[j].cnt); ++ j; } s2[i].ans += ask(s2[i].z); } for (i = l; i < j; ++ i) { add(s2[i].z, -s2[i].cnt); } } int main() { n = read<int>(), k = read<int>(); mx = k; for (int i = 1, x, y, z; i <= n; ++ i) { x = read<int>(), y = read<int>(), z = read<int>(); s1[i].x = x, s1[i].y = y, s1[i].z = z; } sort(s1 + 1, s1 + n + 1, cmp1); int top = 0; for (int i = 1; i <= n; ++ i) { ++ top; if (s1[i].x != s1[i + 1].x || s1[i].y != s1[i + 1].y || s1[i].z != s1[i + 1].z) { s2[++ m].x = s1[i].x; s2[m].y = s1[i].y; s2[m].z = s1[i].z; s2[m].cnt = top; top = 0; } } cdq(1, m); for (int i = 1; i <= m; ++ i) { res[s2[i].ans + s2[i].cnt - 1] += s2[i].cnt; } for (int i = 0; i < n; ++ i) { printf("%d\n", res[i]); } return 0; }
P5094 [USACO04OPEN] MooFest G 加强版
一道比较好的入门题。统计答案的时候稍微麻烦一些。
/* The code was written by yifan, and yifan is neutral!!! */ #include <bits/stdc++.h> using namespace std; typedef long long ll; template<typename T> inline T read() { T x = 0; bool fg = 0; char ch = getchar(); while (ch < '0' || ch > '9') { fg |= (ch == '-'); ch = getchar(); } while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }