CDQ分治
简介
什么是 cdq 分治呢?,其实他是一种思想而不是具体的算法(就和 dp 是一样的),因此 cdq 分治涵盖的范围相当的广泛,由于这样的思路最早是被陈丹琦引入国内的,所以就叫 cdq 分治了。
现在 oi 界对于 cdq 分治这个思想的拓展十分广泛,但是这些都叫 cdq 的东西其实原理和写法上并不相同不过我们可以大概的将它们分为三类:
- cdq 分治解决和点对有关的问题
- cdq 分治优化 1D/1D 动态规划的转移
- 通过 cdq 分治,将一些动态问题转化为静态问题
CDQ 分治解决和点对有关的问题
这类问题一般是给你一个长度为 n 的序列,然后让你统计有一些特性的点对 ( i , j ) (i,j) (i,j)有多少个,又或者说是找到一对点 ( i , j ) (i,j) (i,j)使得一些函数的值最大之类的问题。
那么 cdq 分治基于这样一个算法流程解决这类问题:
设方法 s o l v e ( l , r ) solve(l,r) solve(l,r)寻找 l ≤ i ≤ r , l ≤ j ≤ r l \leq i \leq r,l \leq j \leq r l≤i≤r,l≤j≤r的点对。
因此我们的分治方案如下:
- 将点对分类。先寻找区间中点 m i d mid mid,则第一大类点对为分布在 l ≤ i ≤ m i d , l ≤ j ≤ m i d l \leq i \leq mid,l \leq j \leq mid l≤i≤mid,l≤j≤mid,或 m i d + 1 ≤ i ≤ r , m i d + 1 ≤ j ≤ r mid+1 \leq i \leq r,mid+1 \leq j \leq r mid+1≤i≤r,mid+1≤j≤r的点对,即分布在两个分段上;第二大类点对为分布在 l ≤ i ≤ m i d , m i d + 1 ≤ j ≤ r l \leq i \leq mid,mid+1 \leq j \leq r l≤i≤mid,mid+1≤j≤r或 m i d + 1 ≤ i ≤ r , l ≤ j ≤ m i d mid + 1 \leq i \leq r,l \leq j \leq mid mid+1≤i≤r,l≤j≤mid。
- 针对第一大类点对,我们可以通过继续分治解决。但是针对第二大类点对,根据不同的问题有不同的写法。因此,CDQ分治的核心在于解决第二大类点对的求解。
二维偏序
对于点集合 S S S中的元素 S i = ( a , b ) S_{i}=(a,b) Si=(a,b),请问有多少个点对 ( S i , S j ) (S_{i},S_{j}) (Si,Sj)使得 a i < a j a_{i} < a_{j} ai<aj并且 b i < b j b_{i} < b_{j} bi<bj。
将键 a a a看成是元素的位置,等价于求解序列中逆序对的个数问题。
注意这里序列是有序的,集合是无序的。
首先,我们在CDQ分治之前,在外部按照 a a a键排序集合中的元素,使其变成一个序列。
因此,对于第二类点对,我们先对两个子序列分别按照 b b b键排序,看似打乱了 a a a键,但是我们只针对第二类点对,因此无论两个子序列分别怎么打乱,只有两个子序列中的元素不交换,不等式 a i < a j a_{i} < a_{j} ai<aj恒成立, S i S j S_{i}S_{j} SiSj分别为左右子序列中的元素。然后就变成了一维偏序处理,我们可以使用树状数组等方案实现。
由此可见,CDQ分治的核心在于分离条件的相关性。
ll cdq(int A[], int n)
{
if (n == 1)
return 0;
int mid = n / 2, i = 0, j = mid;
ll ans = cdq(A, mid) + cdq(A + mid, n - mid);
sort(A, A + mid, greater<int>()), sort(A + mid, A + n, greater<int>());
for (; j < n; ++j)
{
while (i < mid && A[i] > A[j])
i++;
ans += i;
}
return ans;
}
三维偏序
和二维偏序类似,增加一维键即可。
先在外部对 a a a键排序,然后针对第二类点对,我们可以先对两个子序列按照 b b b键排序,然后枚举所有的 j j j,使用双指针,将所有满足 b i > b j b_{i} > b_{j} bi>bj的点按照键 c c c插入到树状数组中,然后求和到 c j c_{j} cj即可。
注意,如果存在等于号,要先合并相同元素,否则如果将同种元素分到两个区间,则会造成计数减少。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define FR freopen("in.txt", "r", stdin)
#define FW freopen("out.txt", "w", stdout)
int nn, k, n = 0;
int cnt[200010];
int inline lowbit(int x)
{
return x & -x;
}
struct TreeArray
{
int t[200005];
void add(int idx, int val)
{
while (idx <= k)
{
t[idx] += val;
idx += lowbit(idx);
}
}
int psum(int idx)
{
int ans = 0;
while (idx > 0)
{
ans += t[idx];
idx -= lowbit(idx);
}
return ans;
}
} ta;
struct Element
{
int a;
int b;
int c;
int num;
int ans;
} ek[100005], e[100005];
bool cmpa(const Element &x, const Element &y)
{
if (x.a == y.a)
if (x.b == y.b)
return x.c < y.c;
else
return x.b < y.b;
return x.a < y.a;
}
bool cmpb(const Element &x, const Element &y)
{
if (x.b == y.b)
return x.c < y.c;
return x.b < y.b;
}
void CDQ(int l, int r)
{
if (l == r - 1)
{
return;
}
int mid = (l + r) / 2;
CDQ(l, mid);
CDQ(mid, r);
sort(e + l, e + mid, cmpb);
sort(e + mid, e + r, cmpb);
int i = l;
int j = mid;
while (j < r)
{
while (i < mid && e[i].b <= e[j].b)
{
// insert
ta.add(e[i].c, e[i].num);
i++;
}
// query
e[j].ans += ta.psum(e[j].c);
j++;
}
for (int t = l; t < i; t++)
{
ta.add(e[t].c, -e[t].num);
}
}
int main()
{
scanf("%d %d", &nn, &k);
for (int i = 0; i < nn; i++)
{
scanf("%d %d %d", &ek[i].a, &ek[i].b, &ek[i].c);
}
sort(ek, ek + nn, cmpa);
// merge
for (int i = 0; i < nn; i++)
{
int j = i + 1;
int cnt = 1;
while (j < nn && ek[j].a == ek[i].a && ek[j].b == ek[i].b && ek[j].c == ek[i].c)
{
cnt++;
j++;
}
i = j - 1;
e[n] = ek[i];
e[n].num = cnt;
n++;
}
CDQ(0, n);
for (int i = 0; i < n; i++)
{
cnt[e[i].ans + e[i].num - 1] += e[i].num;
}
for (int i = 0; i < nn; i++)
{
printf("%d\n", cnt[i]);
}
return 0;
}
CDQ套CDQ
待补充。
例题
点对统计,很容易想到CDQ分治。时间复杂度 O ( n log 2 n ) O(n \log^2 n) O(nlog2n)
#include <bits/stdc++.h>
using namespace std;
#define FR freopen("in.txt", "r", stdin)
#define FW freopen("out.txt", "w", stdout)
typedef long long ll;
struct Res
{
ll xi;
ll vi;
} res[20005];
bool cmp1(const Res &a, const Res &b)
{
return a.vi > b.vi;
}
bool cmp2(const Res &a, const Res &b)
{
return a.xi < b.xi;
}
ll cdq(int l, int r)
{
if (l == r - 1)
{
return 0;
}
int mid = (l + r) >> 1;
ll ans = 0;
ans += cdq(l, mid);
ans += cdq(mid, r);
sort(res + l, res + mid, cmp2);
sort(res + mid, res + r, cmp2);
ll sum = 0;
for (int i = mid; i < r; i++)
{
sum += res[i].xi;
}
int rp = mid - 1;
ll psum = 0;
for (int lp = l; lp < mid; lp++)
{
while (rp + 1 < r && res[rp + 1].xi <= res[lp].xi)
{
rp++;
psum += res[rp].xi;
}
ans += res[lp].vi * (res[lp].xi * (rp - mid + 1) - psum + (sum - psum) - res[lp].xi * (r - rp - 1));
}
return ans;
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++)
{
scanf("%lld %lld", &res[i].vi, &res[i].xi);
}
sort(res, res + n, cmp1);
printf("%lld", cdq(0, n));
return 0;
}