upd:整体二分也顺便学了,比较简单不写。
概述
CDQ
分治主要用于解决偏序问题。
通常是先按某一维排序,再递归处理分出来的左子问题对右子问题的答案,最后合并。
经典问题-逆序对
同时也是经典的二维偏序问题。
记属性组 ( a , b ) (a,b) (a,b),其中 a a a 表示位置, b b b 表示值。
那么就是求 a i < a j a_i<a_j ai<aj 且 b i > b j b_i>b_j bi>bj 的个数。
然后就是对 a a a 排序,对 b b b 树状数组从前往后扫一遍统计即可。
时间复杂度 O ( n log n ) \mathcal O(n \log n) O(nlogn)。
三维偏序
有 n n n 个元素,第 i i i 个元素有 a i , b i , c i a_i,b_i,c_i ai,bi,ci 三个属性,设 f ( i ) f(i) f(i) 表示满足 a j ≤ a i a_j \leq a_i aj≤ai 且 b j ≤ b i b_j \leq b_i bj≤bi 且 c j ≤ c i c_j \leq c_i cj≤ci 且 j ≠ i j \ne i j=i 的 j j j 的数量。
对于 d ∈ [ 0 , n ) d \in [0, n) d∈[0,n),求 f ( i ) = d f(i) = d f(i)=d 的数量。
1 ≤ n ≤ 1 0 5 1 \leq n \leq 10^5 1≤n≤105, 1 ≤ a i , b i , c i ≤ k ≤ 2 × 1 0 5 1 \leq a_i,b_i,c_i \leq k \leq 2\times 10^5 1≤ai,bi,ci≤k≤2×105。
先按 a a a 维排序,再将左、右子区间按 b b b 维排序。
虽然现在 a a a 的顺序被打乱了,但是前半边还是都小于后半边的,所以要是只计算前半边对后半边的偏序关系,是不会受到 a a a 的影响的。
考虑到 b b b 已经排序了,直接双指针枚举,树状数组统计 c c c 维贡献即可。
时间复杂度 O ( n log 2 n ) \mathcal O(n \log^2 n) O(nlog2n)。
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
const int _ = 5e5 + 10;
int n, n_, k, tot, cnt[_], c[_];
struct node
{
int x, y, z, ans, w;
} a[_], b[_];
inline bool cmp(node u, node v)
{
if(u.x == v.x)
{
if(u.y == v.y) return u.z < v.z;
return u.y < v.y;
}
return u.x < v.x;
}
inline void update(int x, int val)
{
for(int i = x; i <= k; i += i & -i) c[i] += val;
}
inline int query(int x)
{
int res = 0;
for(int i = x; i; i -= i & -i) res += c[i];
return res;
}
void cdq(int l, int r)
{
if(l == r) return;
int mid = l + r >> 1;
cdq(l, mid), cdq(mid + 1, r);
int pl = l, pr = mid + 1, k = l;
while(pl <= mid && pr <= r)
{
if(a[pl].y <= a[pr].y)
{
update(a[pl].z, a[pl].w);
b[k++] = a[pl++];
}
else
{
a[pr].ans += query(a[pr].z);
b[k++] = a[pr++];
}
}
while(pl <= mid)
{
update(a[pl].z, a[pl].w);
b[k++] = a[pl++];
}
while(pr <= r)
{
a[pr].ans += query(a[pr].z);
b[k++] = a[pr++];
}
for(int i = l; i <= mid; ++i) update(a[i].z, -a[i].w);
for(int i = l; i <= r; ++i) a[i] = b[i];
}
signed main()
{
n_ = read(), k = read();
for(int i = 1; i <= n_; ++i)
b[i].x = read(), b[i].y = read(), b[i].z = read();
sort(b + 1, b + n_ + 1, cmp);
int c = 0;
for(int i = 1; i <= n_; ++i)
{
c++;
if(b[i].x != b[i + 1].x || b[i].y != b[i + 1].y || b[i].z != b[i + 1].z)
a[++n] = b[i], a[n].w = c, c = 0;
}
cdq(1, n);
for(int i = 1; i <= n; ++i)
cnt[a[i].ans + a[i].w - 1] += a[i].w;
for(int i = 0; i < n_; ++i)
printf("%d\n", cnt[i]);
return 0;
}
例题
P3157 [CQOI2011]动态逆序对
现在给出 1 ∼ n 1\sim n 1∼n 的一个排列 a a a,按照某种顺序依次删除 m m m 个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
1 ≤ n ≤ 1 0 5 , 1 ≤ m ≤ 5 × 1 0 4 1 \leq n \leq 10^5,1 \leq m \leq 5 \times 10^4 1≤n≤105,1≤m≤5×104。
对于每一个被删的三元组 ( i , t i , a i ) (i,t_i,a_i) (i,ti,ai)(分别表示第 i i i 个数的位置,删除时间及权值),消失的逆序对 ( i , j ) (i,j) (i,j) 为:
- 满足 j < i , a j > a i , t j > t i j<i,a_j>a_i,t_j>t_i j<i,aj>ai,tj>ti 的 j j j。
- 满足 j > i , a j < a i , t j > t i j > i,a_j<a_i,t_j>t_i j>i,aj<ai,tj>ti 的 j j j。
明显的三维偏序问题,做一次 CDQ
即可。
时间复杂度 O ( n log 2 n ) \mathcal O(n \log^2 n) O(nlog2n)。
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
const int _ = 5e5 + 10;
int n, m, tot, cnt[_], r[_], c[_];
struct node
{
int x, y, z, ans; // x:位置 y: 时间 z: 权值
} a[_], b[_];
inline bool cmp(node u, node v)
{
if(u.x == v.x)
{
if(u.y == v.y) return u.z < v.z;
return u.y < v.y;
}
return u.x < v.x;
}
inline void update(int x, int val)
{
for(int i = x; i <= 1e5 + 10; i += i & -i) c[i] += val;
}
inline int query(int x)
{
int res = 0;
for(int i = x; i; i -= i & -i) res += c[i];
return res;
}
void cdq(int l, int r)
{
if(l == r) return;
int mid = l + r >> 1;
cdq(l, mid), cdq(mid + 1, r);
int ql = l, qr = mid + 1, k = l;
while(ql <= mid && qr <= r)
{
if(a[ql].y > a[qr].y)
{
update(a[ql].z, 1);
b[k++] = a[ql++];
}
else
{
a[qr].ans += query(1e5 + 9) - query(a[qr].z);
b[k++] = a[qr++];
}
}
while(ql <= mid)
{
update(a[ql].z, 1);
b[k++] = a[ql++];
}
while(qr <= r)
{
a[qr].ans += query(1e5 + 9) - query(a[qr].z);
b[k++] = a[qr++];
}
for(int i = l; i <= mid; ++i) update(a[i].z, -1);
ql = l, qr = mid + 1, k = l;
while(ql <= mid && qr <= r)
{
if(a[ql].y < a[qr].y)
{
update(a[qr].z, 1);
b[k++] = a[qr++];
}
else
{
a[ql].ans += query(a[ql].z);
b[k++] = a[ql++];
}
}
while(qr <= r)
{
update(a[qr].z, 1);
b[k++] = a[qr++];
}
while(ql <= mid)
{
a[ql].ans += query(a[ql].z);
b[k++] = a[ql++];
}
for(int i = mid + 1; i <= r; ++i) update(a[i].z, -1);
for(int i = l; i <= r; ++i) a[i] = b[i];
}
bool cmp2(node u, node v)
{
return u.y < v.y;
}
signed main()
{
n = read(), m = read();
for(int i = 1; i <= n; ++i)
{
a[i].x = i;
a[i].z = read();
r[a[i].z] = i;
}
for(int i = 1, p; i <= m; ++i)
a[r[read()]].y = i;
for(int i = 1; i <= n; ++i)
if(a[i].y == 0) a[i].y = m + 1;
long long res = 0;
for(int i = 1; i <= n; ++i)
{
res += query(n + 1) - query(a[i].z);
update(a[i].z, 1);
}
for(int i = 1; i <= n; ++i)
update(a[i].z, -1);
sort(a + 1, a + n + 1, cmp);
cdq(1, n);
sort(a + 1, a + n + 1, cmp2);
for(int i = 1; i <= m; ++i)
{
printf("%lld\n", res);
res -= a[i].ans;
}
return 0;
}
习题:
P4093 [HEOI2016/TJOI2016]序列
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9')
{
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
const int _ = 1e5 + 10;
int n, m, c[_], a[_], mx[_], mn[_], p[_], f[_], ans;
inline void update(int x, int val)
{
for(int i = x; i <= 1e5; i += i & -i)
c[i] = max(c[i], val);
}
inline void clear(int x)
{
for(int i = x; i <= 1e5; i += i & -i)
c[i] = 0;
}
inline int query(int x)
{
int res = 0;
for(int i = x; i; i -= i & -i)
res = max(res, c[i]);
return res;
}
inline bool cmp1(int x, int y)
{
return mx[x] < mx[y];
}
inline bool cmp2(int x, int y)
{
return a[x] < a[y];
}
void cdq(int l, int r)
{
if(l == r)
{
f[l] = max(f[l], 1);
return;
}
int mid = l + r >> 1;
cdq(l, mid);
for(int i = l; i <= r; ++i)
p[i] = i;
sort(p + l, p + mid + 1, cmp1);
sort(p + mid + 1, p + r + 1, cmp2);
int j = l;
for(int i = mid + 1; i <= r; ++i)
{
while(j <= mid && mx[p[j]] <= a[p[i]])
{
update(a[p[j]], f[p[j]]);
++j;
}
f[p[i]] = max(f[p[i]], query(mn[p[i]]) + 1);
}
for(int i = l; i <= mid; ++i)
clear(a[i]);
cdq(mid + 1, r);
}
signed main()
{
n = read(), m = read();
for(int i = 1; i <= n; ++i)
{
a[i] = read();
mx[i] = mn[i] = a[i];
}
for(int i = 1, x, y; i <= m; ++i)
{
x = read(), y = read();
mx[x] = max(mx[x], y);
mn[x] = min(mn[x], y);
}
cdq(1, n);
for(int i = 1; i <= n; ++i)
ans = max(ans, f[i]);
printf("%d\n", ans);
return 0;
}
P8253 [NOI Online 2022 提高组] 如何正确地排序
#include <bits/stdc++.h>
using namespace std;
#define int long long
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
x = x * 10 + c - '0', c = getchar();
return x * f;
}
const int _ = 2e5 + 10, base = 2e5 + 1;
int n, m, a[4][_], ans, c[_ << 1], tot;
struct abc
{
int op, x, y, val;
inline bool operator < (const abc &t) const
{
if(x != t.x) return x < t.x;
return op < t.op;
}
} q[_ << 1];
inline void update(int x)
{
for(int i = x; i < 2 * base; i += i & -i) c[i]++;
}
inline int query(int x)
{
int res = 0;
for(int i = x; i; i -= i & -i) res += c[i];
return res;
}
int calc(int x, int y, int z)
{
int res = 0;
memset(c, 0, sizeof c), tot = 0;
for(int i = 1; i <= n; ++i)
{
q[++tot] = {0ll, a[x][i] - a[y][i] + (x > y), a[y][i] - a[z][i] + (y > z), 0ll};
q[++tot] = {1ll, a[y][i] - a[x][i], a[z][i] - a[y][i], a[y][i]};
}
sort(q + 1, q + tot + 1);
for(int i = 1; i <= tot; ++i)
{
if(!q[i].op) update(q[i].y + base);
else res += q[i].val * query(q[i].y + base);
}
return res;
}
signed main()
{
m = read(), n = read();
for(int i = 0; i < m; ++i)
for(int j = 1; j <= n; ++j) a[i][j] = read();
for(int i = m; i <= 3; ++i)
for(int j = 1; j <= n; ++j) a[i][j] = a[i - m][j];
for(int i = 0; i <= 3; ++i)
for(int j = 1; j <= n; ++j) ans += 2 * n * a[i][j];
for(int i = 0; i <= 3; ++i)
for(int j = 0; j <= 3; ++j)
for(int k = 0; k <= 3; ++k)
if(i != j && j != k && k != i) ans -= calc(i, j, k);
printf("%lld\n", ans);
return 0;
}