LCIS是The LCIS on the Tree的简易版。
3308是一道简单的线段树+区间和并。
4718是将LCIS在树上运行,因此,就需要在LCIS的基础上再加一个树链剖分。
LCIS的代码:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N=1e5+5;
struct Node{
int l,r,lsub,rsub,sub;
} tr[N<<2];
int v[N];
void pushup(int i)
{
tr[i].sub=max(tr[i<<1].sub,tr[i<<1|1].sub);
if(v[tr[i<<1].r]<v[tr[i<<1|1].l]) tr[i].sub=max(tr[i].sub,tr[i<<1].rsub+tr[i<<1|1].lsub);
tr[i].lsub=tr[i<<1].lsub;
if(tr[i<<1].lsub==(tr[i<<1].r-tr[i<<1].l+1)&&v[tr[i<<1].r]<v[tr[i<<1|1].l])
tr[i].lsub+=tr[i<<1|1].lsub;
tr[i].rsub=tr[i<<1|1].rsub;
if(tr[i<<1|1].rsub==(tr[i<<1|1].r-tr[i<<1|1].l+1)&&v[tr[i<<1].r]<v[tr[i<<1|1].l])
tr[i].rsub+=tr[i<<1].rsub;
}
void build(int i,int l,int r)
{
tr[i].l=l; tr[i].r=r; tr[i].lsub=tr[i].sub=tr[i].rsub=1;
if(l==r) return ;
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
pushup(i);
}
void update(int i,int p,int c)
{
if(tr[i].l==tr[i].r)
{
v[tr[i].l]=c;
return ;
}
int mid=(tr[i].l+tr[i].r)>>1;
if(p<=mid) update(i<<1,p,c);
else update(i<<1|1,p,c);
pushup(i);
}
int query(int i,int l,int r)
{
if(tr[i].l==l&&tr[i].r==r)
return tr[i].sub;
int mid=(tr[i].l+tr[i].r)>>1;
if(r<=mid) return query(i<<1,l,r);
else if(l>mid) return query(i<<1|1,l,r);
else
{
int ans=max(query(i<<1,l,mid),query(i<<1|1,mid+1,r));
if(v[tr[i<<1].r]<v[tr[i<<1|1].l])
ans=max(ans,min(tr[i<<1].rsub,mid-l+1)+min(tr[i<<1|1].lsub,r-mid));
return ans;
}
}
int main()
{
int t,n,m,a,b;
char op[5];
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
scanf("%d",&v[i]);
build(1,0,n-1);
while(m--)
{
scanf("%s%d%d",op,&a,&b);
if(op[0]=='Q')
printf("%d\n",query(1,a,b));
else
update(1,a,b);
}
}
return 0;
}
除了树链剖分,若路径为一个子节点到另一个子节点时,则可能存在从出发子节点到最小公共祖先节点路径的递减最长左序列+最小公共祖先节点到目的子节点路径的递增最长做序列。因此线段树需要对每一个区间记录递增、递减两种情况。
The LCIS on the Tree的代码:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N=1e5+5;
struct Edge{
int to,next;
} e[N<<1];
struct Tree{
int l,r;
int il,ir,ics; //递增时最长左区间、最长右区间和最长区间
int dl,dr,dcs; //递减时最长左区间、最长右区间和最长区间
Tree(){l=r=il=ir=ics=dl=dr=dcs=0;}
} tr[N<<2];
int tot,head[N];
int son[N],id[N],fa[N],deep[N],size[N],top[N],rk[N],cnt;
int val[N];
void add(int u,int v)
{
e[++tot].to=v;
e[tot].next=head[u];
head[u]=tot;
}
void dfs1(int u,int f,int d)
{
deep[u]=d; fa[u]=f; size[u]=1; son[u]=0;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
dfs1(v,u,d+1);
size[u]+=size[v];
if(size[v]>size[son[u]])
son[u]=v;
}
}
void dfs2(int u,int t)
{
id[u]=++cnt; rk[cnt]=u; top[u]=t;
if(!son[u]) return ;
dfs2(son[u],t);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v!=son[u]&&v!=fa[u])
dfs2(v,v);
}
}
void pushup(Tree l,Tree r,Tree &f)
{
if(r.ics==0) {f=l; return ;}
f.l=l.l; f.r=r.r;
f.il=l.il; f.ir=r.ir; f.ics=max(l.ics,r.ics);
if(val[rk[l.r]]<val[rk[r.l]])
{
f.ics=max(f.ics,l.ir+r.il);
if(l.il==l.r-l.l+1)
f.il+=r.il;
if(r.ir==r.r-r.l+1)
f.ir+=l.ir;
}
f.dl=l.dl; f.dr=r.dr; f.dcs=max(l.dcs,r.dcs);
if(val[rk[l.r]]>val[rk[r.l]])
{
f.dcs=max(f.dcs,l.dr+r.dl);
if(l.dl==l.r-l.l+1)
f.dl+=r.dl;
if(r.dr==r.r-r.l+1)
f.dr+=l.dr;
}
}
void build(int i,int l,int r)
{
tr[i].l=l; tr[i].r=r;
if(l==r)
{
tr[i].il=tr[i].ir=tr[i].ics=1;
tr[i].dl=tr[i].dr=tr[i].dcs=1;
return ;
}
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
pushup(tr[i<<1],tr[i<<1|1],tr[i]);
}
Tree query(int i,int l,int r)
{
if(tr[i].l==l&&tr[i].r==r)
return tr[i];
int mid=(tr[i].l+tr[i].r)>>1;
if(r<=mid) return query(i<<1,l,r);
else if(l>mid) return query(i<<1|1,l,r);
else
{
Tree ans;
pushup(query(i<<1,l,mid),query(i<<1|1,mid+1,r),ans);
return ans;
}
}
//保证当ans1和ans2共同存在与在路径上时,ans1取最长递减区间,ans2取最长递增区间
void solve(int a,int b)
{
int flag=1;
Tree ans1,ans2;
if(deep[top[a]]<deep[top[b]]) {swap(a,b); swap(ans1,ans2); flag^=1;}
while(top[a]!=top[b])
{
pushup(query(1,id[top[a]],id[a]),ans1,ans1);
a=fa[top[a]];
if(deep[top[a]]<deep[top[b]]) {swap(a,b); swap(ans1,ans2); flag^=1;}
}
if(id[a]>id[b]) {swap(a,b); swap(ans1,ans2); flag^=1;}
pushup(query(1,id[a],id[b]),ans2,ans2);
if(!flag) swap(ans1,ans2);
int res=max(ans1.dcs,ans2.ics);
if(val[rk[ans1.l]]<val[rk[ans2.l]]) res=max(res,ans1.dl+ans2.il);
printf("%d\n",res);
}
int main()
{
int t,n,p,q,u,v,casenum=1;
scanf("%d",&t);
while(t--)
{
memset(head,0,sizeof(head));
cnt=tot=0;
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
for(int i=2;i<=n;i++)
{
scanf("%d",&p);
add(i,p); add(p,i);
}
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
printf("Case #%d:\n",casenum++);
scanf("%d",&q);
while(q--)
{
scanf("%d%d",&u,&v);
solve(u,v);
}
if(t)
printf("\n");
}
return 0;
}