问题 B: 树上逆序对
时间限制: 1 Sec 内存限制: 256 MB
题目描述
一天,Chika 在研究关于所谓的树上逆序对的问题,你能帮助她吗?她会给你一棵有根树,这棵树有 n 个 结点,被编号为 1~ n,1 号结点是根。每个点有一个权值,i 号结点的权值为 a[i]。如果 u 是 v 的祖先结点, 并且 a[u] > a[v],那么 (u,v) 被称作一个“** 逆序对 **”。
Chika 会给你 m 个任务,包含两种类型:
1 u x : 向树中添加一个新结点,其父亲为 u,权值为 x。执行完这个操作后,树的结点总数增加 1,因此该 结点的编号为 n+1(n 为添加这个点之前树中的总结点数)。
2 u : 你需要回答如果删除以 u 为根的子树,树的剩余部分的逆序对数。任意两个类型 2 的任务之间互相 独立。
你需要完成所有任务。
输入
输入文件的第一行包含两个整数 n (1≤n≤10^5) 和 m (1≤m≤10^5),代表树中初始状态的结点总数以及 任务总数。
第二行包含 n 个整数,第 i 个整数是 a[i] (1≤ a[i] ≤10^9),代表 i 号结点的权值。
第三行包含 n−1 个整数,第 i 个整数是 i+1 号结点的父结点。
接下来是 m 行,每一行描述了一个任务,满足以下两种格式之一:
1 u x
2 u
对于每个任务,u 都满足不小于 1 且不超过该时刻树中结点的总数。对于每个类型 1 的任务,x 都满足 1≤ x ≤109。
你可以在本题的描述中找到任务的含义。
输出
对于每个类型2的任务,你应该输出一行,仅包含一个数字,代表你的答案。
样例输入
10 10 10 2 9 2 8 7 8 3 1 8 1 1 2 3 5 2 5 6 2 1 1 6 2 3 1 9 5 2 1 2 7 1 12 2 2 6 1 11 13 2 10 2 11
样例输出
5 0 21 11 26 26
线段树+离线处理
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
#define N 200003
#define mid (l+r>>1)
#define lc (d<<1)
#define rc (d<<1|1)
typedef long long ll;
int fv[N], val[N], bit[N], cnt, tid[N], en[N], tot;
int ucnt;
vector<int> V[N];
struct Q{
int op, u, x;
}q[N/2];
ll tr[N<<2];
void Push(int d) {tr[d] = tr[lc]+tr[rc];}
void build(int d, int l, int r) {
if (l == r) {
tr[d] = 0;
return ;
}
build(lc, l, mid);
build(rc, mid+1, r);
Push(d);
}
void update(int d, int l, int r, int pos, int ad) {
if (l == r) {
tr[d] += ad;
return;
}
if (pos <= mid) update(lc, l, mid, pos, ad);
else update(rc, mid+1, r, pos, ad);
Push(d);
}
ll query(int d, int l, int r, int L, int R) {
if (L > R) return 0;
if (l == L && r == R) {
return tr[d];
}
if (R <= mid) return query(lc, l, mid, L, R);
else if (L > mid) return query(rc, mid+1, r, L, R);
else return query(lc, l, mid, L, mid)+query(rc, mid+1, r, mid+1, R);
}
void dfs(int u, int f) {
tid[u] = ++tot;
int rk = lower_bound(bit+1, bit+ucnt, val[u])-bit;
fv[u] = query(1, 1, ucnt, rk+1, ucnt);
update(1, 1, ucnt, rk, 1);
int i, v;
for (i = 0;i < V[u].size();i++) {
v = V[u][i];
if (v == f) continue;
dfs(v, u);
}
update(1, 1, ucnt, rk, -1);
en[u] = tot;
}
int main() {
int n, m, u, v, i, j;
while (~scanf("%d%d", &n, &m)) {
for (i = 1;i <= n;i++) scanf("%d", val+i);
for (i = 1;i <= n+m;i++) V[i].clear();
for (i = 2;i <= n;i++) {
scanf("%d", &u);
V[u].push_back(i);
}
cnt = n;
for (i = 1;i <= m;i++) {
scanf("%d%d", &q[i].op, &q[i].u);
if (q[i].op == 1) {
scanf("%d", &q[i].x);
V[q[i].u].push_back(++cnt);
val[cnt] = q[i].x;
}
}
for (i = 1;i <= cnt;i++) {
bit[i] = val[i];
}
sort(bit+1, bit+1+cnt);
ucnt = unique(bit+1, bit+1+cnt)-bit-1;
build(1, 1, ucnt);
tot = 0;
dfs(1, 0);
build(1, 1, tot);
for (i = 1;i <= n;i++) {
update(1, 1, tot, tid[i], fv[i]);
}
int id = n;
for (i = 1;i <= m;i++) {
if (q[i].op == 1) {
id++;
update(1, 1, tot, tid[id], fv[id]);
} else {
printf("%lld\n", tr[1]-query(1, 1, tot, tid[q[i].u], en[q[i].u]));
}
}
}
}