F[x][i]表示x的子树中取的数字<=i的最大值,线段树合并优化DP
写得很难看,并不知道好看的写法
#include<cstdio>
#include<algorithm>
using namespace std;
int cnt,n,Num,ans,last[200005],tag[10000005],ls[10000005],rs[13000005],E[200005],a[200005],Fa[200005],root[200005],ANS[200005],tree[10000005];
struct node{
int to,next;
}e[1000005];
void add(int a,int b){
e[++cnt].to=b;
e[cnt].next=last[a];
last[a]=cnt;
}
void push_down(int x){
if (ls[x]) tag[ls[x]]+=tag[x],tree[ls[x]]+=tag[x],tree[ls[x]]=max(tree[ls[x]],tree[x]);
if (rs[x]) tag[rs[x]]+=tag[x],tree[rs[x]]+=tag[x],tree[rs[x]]=max(tree[rs[x]],tree[x]);
tag[x]=0;
}
int merge(int x,int y){
if (!x) return y;
if (!y) return x;
push_down(x); push_down(y);
if (!ls[x]) ls[x]=ls[y],tree[ls[x]]+=tree[x],tag[ls[x]]+=tree[x];
else if (!ls[y]) tree[ls[x]]+=tree[y],tag[ls[x]]+=tree[y];
else ls[x]=merge(ls[x],ls[y]);
if (!rs[x]) rs[x]=rs[y],tree[rs[x]]+=tree[x],tag[rs[x]]+=tree[x];
else if (!rs[y]) tree[rs[x]]+=tree[y],tag[rs[x]]+=tree[y];
else rs[x]=merge(rs[x],rs[y]);
tree[x]+=tree[y];
return x;
}
int query(int t,int l,int r,int x){
if (!t) return 0;
if (l==r) return tree[t];
push_down(t);
int mid=(l+r)>>1;
if (x<=mid) return max(tree[t],query(ls[t],l,mid,x));
else return max(tree[t],query(rs[t],mid+1,r,x));
}
void insert(int &t,int l,int r,int x,int y,int Val){
if (l>y || r<x) return;
if (!t) t=++cnt;
if (l>=x && r<=y){
tree[t]=max(tree[t],Val);
return;
}
push_down(t);
int mid=(l+r)>>1;
insert(ls[t],l,mid,x,y,Val);
insert(rs[t],mid+1,r,x,y,Val);
}
void solve(int x){
for (int i=last[x]; i; i=e[i].next){
int V=e[i].to;
solve(V);
root[x]=merge(root[x],root[V]);
}
int Key=query(root[x],1,Num,a[x]-1);
insert(root[x],1,Num,a[x],Num,Key+1);
}
int main(){
scanf("%d",&n);
for (int i=1; i<=n; i++) {
scanf("%d%d",&a[i],&Fa[i]);
if (Fa[i]) add(Fa[i],i);
}
E[++n]=-1e9;
for (int i=1; i<=n; i++) E[i]=a[i];
sort(E+1,E+n+1);
Num=unique(E+1,E+n+1)-E-1;
for (int i=1; i<=n; i++) a[i]=lower_bound(E+1,E+Num+1,a[i])-E;
cnt=0;
solve(1);
for (int i=1; i<=Num; i++) ANS[i]=query(root[1],1,Num,i);
int ans=0;
for (int i=1; i<=Num; i++) ans=max(ans,ANS[i]);
printf("%d\n",ans);
return 0;
}