读错题了，于是写了一个树上 LIS（求树上某条路径上的最长不降子序列），应该是对的吧……
code:
#include <bits/stdc++.h>
// Capacity bound for per-vertex arrays (vertices are indexed 1..n, so n must stay below N).
#define N 200005
// Shorthand for long long (not referenced by the visible code).
#define LL long long
// Local-testing helper: redirect stdin from "<s>.in" (its call in main is commented out).
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
// Pool-allocated sparse segment tree over a position range (the file uses 1..n).
// Every node stores the largest value written anywhere inside its range.
// Index 0 plays the role of the empty tree: it is never written, so its
// maxx stays 0. merge() only writes freshly allocated nodes, so roots that
// were merged into another tree can still be queried afterwards.
struct seg
{
    int tot;
    struct node
    {
        int ls,rs,maxx;
    }t[N*50];

    // Hand out the next unused node of the pool.
    int newnode() { return ++tot; }

    // Raise position p to at least v, allocating missing nodes on the way down.
    // x is the (possibly 0) root slot; it is filled in if it was empty.
    void update(int &x,int l,int r,int p,int v)
    {
        if(x==0) x=newnode();
        node &cur=t[x];
        cur.maxx=max(cur.maxx,v);
        if(l!=r)
        {
            int mid=(l+r)>>1;
            if(p>mid) update(cur.rs,mid+1,r,p,v);
            else update(cur.ls,l,mid,p,v);
        }
    }

    // Combine two trees into a brand-new one; subtrees present on only one
    // side are shared rather than copied.
    int merge(int x,int y)
    {
        if(x==0) return y;
        if(y==0) return x;
        int fresh=newnode();
        t[fresh].maxx=max(t[x].maxx,t[y].maxx);
        int left=merge(t[x].ls,t[y].ls);
        int right=merge(t[x].rs,t[y].rs);
        t[fresh].ls=left;
        t[fresh].rs=right;
        return fresh;
    }

    // Maximum stored value over positions [L, R] (0 if nothing is stored there).
    int query(int x,int l,int r,int L,int R)
    {
        if(x==0) return 0;
        if(L<=l&&r<=R) return t[x].maxx;
        int mid=(l+r)>>1;
        int fromLeft=(L<=mid)?query(t[x].ls,l,mid,L,R):0;
        int fromRight=(R>mid)?query(t[x].rs,mid+1,r,L,R):0;
        return max(fromLeft,fromRight);
    }
}in,de;
// Scratch ordered multisets and iterators into them.
multiset<int> s1;
multiset<int> s2;
multiset<int>::iterator it1;
multiset<int>::iterator it2;
// n: number of vertices; edges: adjacency-list entries used so far.
int n;
int edges;
// Per-vertex data, indexed 1..n.
int rt_in[N];   // root of the vertex's tree inside `in`
int rt_de[N];   // root of the vertex's tree inside `de`
int val[N];     // vertex value (replaced by its rank in main)
int ans[N];     // best sequence length recorded at the vertex
int A[N];       // sorted copy of the raw values, used for rank compression
// Adjacency list: every tree edge is added in both directions, hence N<<1 slots.
int hd[N];
int to[N<<1];
int nex[N<<1];
// Per-child scratch arrays.
int max1[N];
int max2[N];
void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs(int u,int ff)
{
int max_in=1,max_de=1;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
dfs(v,u);
max_in=max(max_in, in.query(rt_in[v],1,n,val[u],n)+1);
max_de=max(max_de, de.query(rt_de[v],1,n,1,val[u])+1);
}
int tl=0;
s1.clear();
s2.clear();
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
max1[++tl]=in.query(rt_in[v],1,n,val[u],n);
max2[tl]=de.query(rt_de[v],1,n,1,val[u]);
s1.insert(max1[tl]);
s2.insert(max2[tl]);
}
if(tl>1)
{
for(int i=1;i<=tl;++i)
{
s1.erase(s1.lower_bound(max1[i]));
s2.erase(s2.lower_bound(max2[i]));
it1=s1.end(), it2=s2.end();
it1--,it2--;
ans[u]=max(ans[u], max2[i]+(*it1)+1);
ans[u]=max(ans[u], max1[i]+(*it2)+1);
s1.insert(max1[i]);
s2.insert(max2[i]);
}
}
ans[u]=max(ans[u],max(max_in,max_de));
in.update(rt_in[u],1,n,val[u],max_in);
de.update(rt_de[u],1,n,val[u],max_de);
if(ff)
{
rt_in[ff]=in.merge(rt_in[ff],rt_in[u]);
rt_de[ff]=de.merge(rt_de[ff],rt_de[u]);
}
}
int main()
{
// setIO("input");
int i,j;
scanf("%d",&n);
for(i=1;i<=n;++i) scanf("%d",&val[i]), A[i]=val[i];
sort(A+1,A+1+n);
for(i=1;i<=n;++i) val[i]=lower_bound(A+1,A+1+n,val[i])-A;
for(i=2;i<=n;++i)
{
int x;
scanf("%d",&x);
add(x,i),add(i,x);
}
dfs(1,0);
int tmp=0;
for(i=1;i<=n;++i) tmp=max(tmp,ans[i]);
printf("%d\n",tmp);
return 0;
}