Description
小A想做一棵很大的树,但是他手上的材料有限,只好用点小技巧了。开始,小A只有一棵结点数为N的树,结
点的编号为1,2,…,N,其中结点1为根;我们称这颗树为模板树。小A决定通过这棵模板树来构建一颗大树。构建过
程如下:(1)将模板树复制为初始的大树。(2)以下(2.1)(2.2)(2.3)步循环执行M次(2.1)选择两个数字a,b,
其中1<=a<=N,1<=b<=当前大树的结点数。(2.2)将模板树中以结点a为根的子树复制一遍,挂到大树中结点b的下
方(也就是说,模板树中的结点a为根的子树复制到大树中后,将成为大树中结点b的子树)。(2.3)将新加入大树
的结点按照在模板树中编号的顺序重新编号。例如,假设在进行2.2步之前大树有L个结点,模板树中以a为根的子
树共有C个结点,那么新加入模板树的C个结点在大树中的编号将是L+1,L+2,…,L+C;大树中这C个结点编号的大小
顺序和模板树中对应的C个结点的大小顺序是一致的。下面给出一个实例。假设模板树如下图:
根据第(1)步,初始的大树与模板树是相同的。在(2.1)步,假设选择了a=4,b=3。运行(2.2)和(2.3)后,得到新的
大树如下图所示
现在他想问你,树中一些结点对的距离是多少。
Input
第一行三个整数:N,M,Q,以空格隔开,N表示模板树结点数,M表示第(2)中的循环操作的次数,Q 表示询问数
量。接下来N-1行,每行两个整数 fr,to,表示模板树中的一条树边。再接下来M行,每行两个整数x,to,表示将模
板树中 x 为根的子树复制到大树中成为结点to的子树的一次操作。再接下来Q行,每行两个整数fr,to,表示询问
大树中结点 fr和 to之间的距离是多少。
Output
输出Q行,每行一个整数,第 i行是第 i个询问的答案。
Sample Input
1 4
1 3
4 2
4 5
4 3
3 2
6 9
1 8
5 3
Sample Output
3
3
HINT
经过两次操作后,大树变成了下图所示的形状:
结点6到9之间经过了6条边,所以距离为6;类似地,结点1到8之间经过了3条边;结点5到3之间也经过了3条边。
题解:
可以发现这个题需要的操作就是求两点的lca和某个点的深度.
我们首先建出模板树,
然后把每次复制产生的新块缩成一个点,模板树也缩成一个点,这样可以建出缩点后的大树.
两个操作的思路都是先在大树上走到目标块,然后在模板树里处理.
我们需要知道每次进入模板树的点.
这个分析一下就是查询模板树一个子树中某个点的排名和第k小的点.
用dfs序和主席树处理即可.
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define N 100010
using namespace std;
int siz1[N],siz2[N],cnt,sz,now,point[N],bl1[N],bl2[N],fa[N],pos[N],fa2[N],next[N<<2],head[N];
int ls[N*20],rs[N*20],root[N],sum[N*20],a[N],n,m,Q,x,y,dep1[N],num,be[N],ed[N],son[N];
ll dep2[N],b[N],xu[N];
struct use{int st,en,v;}e[N<<2];
ll read(){
ll x(0);char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
void add(int x,int y){
next[++cnt]=point[x];point[x]=cnt;
e[cnt].st=x;e[cnt].en=y;
}
void link(int x,int y,int v){
next[++cnt]=head[x];head[x]=cnt;
e[cnt].st=x;e[cnt].en=y;e[cnt].v=v;
}
void dfs_a1(int x){
siz1[x]=1;be[x]=++num;pos[num]=x;
for (int i=point[x];i;i=next[i])
if (e[i].en!=fa[x]){
dep1[e[i].en]=dep1[x]+1;fa[e[i].en]=x;
dfs_a1(e[i].en);siz1[x]+=siz1[e[i].en];
}
ed[x]=num;
}
void dfs_a2(int x,int c){
bl1[x]=c;int k(0);
for (int i=point[x];i;i=next[i])
if (e[i].en!=fa[x]&&siz1[e[i].en]>siz1[k]) k=e[i].en;
if (!k) return;dfs_a2(k,c);
for (int i=point[x];i;i=next[i])
if (e[i].en!=fa[x]&&e[i].en!=k) dfs_a2(e[i].en,e[i].en);
}
void dfs_b1(int x){
siz2[x]=1;
for (int i=head[x];i;i=next[i])
if (e[i].en!=fa2[x]){
dep2[e[i].en]=dep2[x]+e[i].v;fa2[e[i].en]=x;
dfs_b1(e[i].en);siz2[x]+=siz2[e[i].en];
}
}
void dfs_b2(int x,int c){
bl2[x]=c;int k=0;
for (int i=head[x];i;i=next[i])
if (e[i].en!=fa2[x]&&siz2[e[i].en]>siz2[k]) k=e[i].en;
if (!k) return;son[x]=k;dfs_b2(k,c);
for (int i=head[x];i;i=next[i])
if (e[i].en!=fa2[x]&&e[i].en!=k) dfs_b2(e[i].en,e[i].en);
}
void insert(int &y,int x,int l,int r,int v){
y=++sz;int mid=(l+r)>>1;
ls[y]=ls[x];rs[y]=rs[x];sum[y]=sum[x]+1;
if (l==r) return;
if (v<=mid) insert(ls[y],ls[x],l,mid,v);
else insert(rs[y],rs[x],mid+1,r,v);
}
int query(int x,int y,int l,int r,int k){
if (l==r) return l;
int mid=(l+r)>>1;
if (sum[ls[y]]-sum[ls[x]]>=k) return query(ls[x],ls[y],l,mid,k);
else return query(rs[x],rs[y],mid+1,r,k-sum[ls[y]]+sum[ls[x]]);
}
int ask(int x,int k){
return query(root[be[x]-1],root[ed[x]],1,n,k);
}
int find(ll x){
int l=1,r=now;
while (l<=r){
int mid=(l+r)>>1;
if (xu[mid]<x) l=mid+1;
else r=mid-1;
}
return l-1;
}
int getrank(int x,int y,int l,int r,int v){
if (l==r) return sum[y]-sum[x];
int mid=(l+r)>>1;
if (v<=mid) return getrank(ls[x],ls[y],l,mid,v);
else return getrank(rs[x],rs[y],mid+1,r,v)+sum[ls[y]]-sum[ls[x]];
}
int rank(int x,int k){
return getrank(root[be[x]-1],root[ed[x]],1,n,k);
}
int lca1(int x,int y){
while (bl1[x]!=bl1[y]){
if (dep1[bl1[x]]<dep1[bl1[y]]) swap(x,y);
x=fa[bl1[x]];
}
if (dep1[x]<dep1[y]) return x;
else return y;
}
ll lca2(ll u,ll v){
int x=find(u),y=find(v);
while (bl2[x]!=bl2[y]){
if (dep2[bl2[x]]<dep2[bl2[y]]) swap(x,y),swap(u,v);
x=bl2[x];u=b[x];x=fa2[x];
}
if (x!=y){
if (dep2[x]>dep2[y]) swap(x,y),swap(u,v);
y=x;v=b[son[y]];
}
return xu[x]+rank(a[x],lca1(ask(a[x],u-xu[x]),ask(a[y],v-xu[y])));
}
ll getdeep(ll x){
int t=find(x);
return dep2[t]+dep1[ask(a[t],x-xu[t])]-dep1[a[t]];
}
ll getans(ll x,ll y){
return getdeep(x)+getdeep(y)-2*getdeep(lca2(x,y));
}
int main(){
n=read();m=read();Q=read();
for (int i=1;i<n;i++){
x=read();y=read();
add(x,y);add(y,x);
}
dfs_a1(1);dfs_a2(1,0);
for (int i=1;i<=n;i++) insert(root[i],root[i-1],1,n,pos[i]);
now=1;a[1]=1;
for (int i=2;i<=m+1;i++){
a[i]=read();b[i]=read();
xu[i]=xu[i-1]+siz1[a[i-1]];
int t=find(b[i]),v=dep1[ask(a[t],b[i]-xu[t])]-dep1[a[t]]+1;
link(i,t,v);link(t,i,v);now++;
}
dfs_b1(1);dfs_b2(1,1);
for (int i=1;i<=Q;i++){
ll x=read(),y=read();
printf("%lld\n",getans(x,y));
}
}