题面
题意
给出两棵树A,B,其根节点都是1,现在删掉A树中的一条边(u,v),然后将B树的所有满足p,q两点中有一点在u,v的公共子树中,另外一个点不在的边(p,q)删去,然后再按上述规则来删掉A树中的所有满足条件的边,输出删边顺序。
做法
首先可以对于一棵树上的边(p,q),如果它要被删掉,当且仅当另外一棵树上路径pq中的任意一条边被删去,而这条路径可以通过树链剖分转化为线段树上的几段区间。
这样问题就转化为了,每个数对应着至多log段区间,这些区间中任何一个点被删去,这个点就会被删去。可以用线段树来维护这个,对这些区间都打上可持久化标记,这样最多有nloglog个标记,当要删除某个点时,在线段树上找到这个点,并把它在线段树上的路径上的所有点中的所有标记的点都删光,并清楚所有标记,如此反复即可。
代码
#include<bits/stdc++.h>
#define N 200100
using namespace std;
int n;
bool A,B;
vector<int>ans[2];
struct Tree
{
int tt,son[N],top[N],fa[N],deep[N],in[N];
bool gg[N];
struct Node
{
int ls,rs;
vector<int>num;
}node[N<<1];
vector<int>to[N];
void add(int u,int v){to[u].push_back(v);}
int dfs(int now)
{
int i,t,tmp,res=1,mx=0;
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
fa[t]=now;
deep[t]=deep[now]+1;
res+=tmp=dfs(t);
if(tmp>mx)
{
mx=tmp;
son[now]=t;
}
}
return res;
}
void Dfs(int now)
{
int i,t;
in[now]=++tt;
if(son[now])
{
top[son[now]]=top[now];
Dfs(son[now]);
}
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
if(t==son[now]) continue;
top[t]=t;
Dfs(t);
}
}
void build(int now,int l,int r)
{
if(l==r) return;
int mid=((l+r)>>1);
node[now].ls=++tt;
build(tt,l,mid);
node[now].rs=++tt;
build(tt,mid+1,r);
}
void add(int now,int l,int r,int u,int v,int w)
{
if(u<=l&&r<=v)
{
node[now].num.push_back(w);
return;
}
int mid=((l+r)>>1);
if(u<=mid) add(node[now].ls,l,mid,u,v,w);
if(mid<v) add(node[now].rs,mid+1,r,u,v,w);
}
void del(int now,int l,int r,int u)
{
int i,t;
for(i=0;i<node[now].num.size();i++)
{
t=node[now].num[i];
if(!gg[t])
{
gg[t]=1;
ans[A].push_back(t);
}
}
node[now].num.clear();
if(l==r) return;
int mid=((l+r)>>1);
if(u<=mid) del(node[now].ls,l,mid,u);
else del(node[now].rs,mid+1,r,u);
}
void pre()
{
int i,j;
deep[1]=top[1]=1;
dfs(1);
Dfs(1);
build(tt=1,1,n);
}
void ad(int u,int v,int w)
{
for(;top[u]!=top[v];)
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
add(1,1,n,in[top[u]],in[u],w);
u=fa[top[u]];
}
if(deep[u]>deep[v]) swap(u,v);
if(u!=v) add(1,1,n,in[u]+1,in[v],w);
}
void del(int u){del(1,1,n,in[u]);}
}tree[2];
int main()
{
int i,j,p,q;
cin>>n;
for(i=2;i<=n;i++) scanf("%d",&p),tree[0].add(p,i);
for(i=2;i<=n;i++) scanf("%d",&p),tree[1].add(p,i);
tree[0].pre(),tree[1].pre();
for(i=2;i<=n;i++)
{
tree[0].ad(tree[1].fa[i],i,i);
tree[1].ad(tree[0].fa[i],i,i);
}
cin>>p,ans[A=1].push_back(p+1),tree[1].gg[p+1]=1;
for(i=0;;i^=1)
{
swap(A,B);
puts(i&1?"Red":"Blue");
sort(ans[B].begin(),ans[B].end());
for(j=0;j<ans[B].size();j++) printf("%d ",ans[B][j]-1);puts("");
ans[A].clear();
for(j=0;j<ans[B].size();j++) tree[i].del(ans[B][j]);
if(!ans[A].size()) return 0;
}
}