首先那个大树是由若干小树组成的,那么把每一颗小树抽象成一个点,这颗大树就变成节点为M+1的树了。同时定义树上相邻点(x,fa[x])的距离为小树x的根到小树fa[x]的根的距离。那么此时查询两个点x,y流程如下:
1.得到x所在小树u,y所在小树v;
2.若u=v,直接查询;否则令w=lca(u,v)(这里的lca为大树中的lca),然后分lca=u(v)和lca!=u且lca!=v两种情况讨论一下就好了。
那么查询某个点x所在小树可以二分查找,然后就变成查询子树第k大的经典主席树问题。
lca用树链剖分不知道会不会快一点(⊙﹏⊙)b。
AC代码如下:
<pre name="code" class="cpp">#include<iostream>
#include<cstdio>
#include<cstring>
#define N 100005
#define M 2000005
#define ll long long
using namespace std;
int n,m,trtot,dfsclk,sum[M],ls[M],rs[M],rt[N],lf[N],rg[N],id[N];
struct node{
int id,rt,fa; ll l,r;
}a[N];
ll read(){
ll x=0; char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x;
}
void ins(int l,int r,int x,int &y,int k){
y=++trtot; sum[y]=sum[x]+1;
if (l==r) return; int mid=(l+r)>>1;
if (k<=mid){ rs[y]=rs[x]; ins(l,mid,ls[x],ls[y],k); }
else{ ls[y]=ls[x]; ins(mid+1,r,rs[x],rs[y],k); }
}
int qry(int k,int z){
int l=1,r=n,mid,tmp,x=rt[lf[k]-1],y=rt[rg[k]];
while (l<r){
mid=(l+r)>>1; tmp=sum[ls[y]]-sum[ls[x]];
if (z<=tmp){ r=mid; x=ls[x]; y=ls[y]; }
else{ l=mid+1; z-=tmp; x=rs[x]; y=rs[y]; }
}
return l;
}
int getid(ll x,int ed){
int l=1,r=ed+1,mid;
while (l+1<r){
mid=(l+r)>>1;
if (a[mid].l<=x) l=mid; else r=mid;
}
return l;
}
struct tree_node{
int tot,fst[N],pnt[N<<1],len[N<<1],nxt[N<<1],fa[N];
int sz[N],son[N],anc[N];
ll d[N];
void add(int x,int y,int z){
pnt[++tot]=y; len[tot]=z; nxt[tot]=fst[x]; fst[x]=tot;
}
void dfs(int x){
int p; sz[x]=1;
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x]){
fa[y]=x; d[y]=d[x]+len[p];
dfs(y); sz[x]+=sz[y];
if (sz[y]>sz[son[x]]) son[x]=y;
}
}
}
void nbr(int x,int tp){
lf[x]=rg[x]=++dfsclk; id[dfsclk]=x; anc[x]=tp; int p;
if (son[x]){
nbr(son[x],tp); rg[x]=rg[son[x]];
}
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x] && y!=son[x]){
nbr(y,y); rg[x]=rg[y];
}
}
}
void divide(int x,int tp){
anc[x]=tp; int p;
if (son[x]) divide(son[x],tp);
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x] && y!=son[x]) divide(y,y);
}
}
int lca(int x,int y){
for (; anc[x]!=anc[y]; x=fa[anc[x]])
if (d[anc[x]]<d[anc[y]]) swap(x,y);
return (d[x]<d[y])?x:y;
}
int gettp(int x,int y){
int z;
for (; anc[x]!=anc[y]; x=fa[anc[x]]) z=anc[x];
return (x==y)?z:son[y];
}
void build(){
int i;
for (i=1; i<=n; i++) ins(1,n,rt[i-1],rt[i],id[i]);
}
ll dist(int x,int y){
return d[x]+d[y]-(d[lca(x,y)]<<1);
}
}t1,t2;
int main(){
n=read(); m=read(); int cas=read(),i,z; ll x,y;
for (i=1; i<n; i++){
x=read(); y=read();
t1.add(x,y,1); t1.add(y,x,1);
}
t1.dfs(1); t1.nbr(1,1); t1.build();
a[1].id=1; a[1].rt=1; a[1].l=1; a[1].r=n;
for (i=1; i<=m; i++){
x=read(); y=read();
a[i+1].rt=x; a[i+1].id=i+1;
a[i+1].l=a[i].r+1; a[i+1].r=a[i].r+t1.sz[x];
z=getid(y,i); a[i+1].fa=y=qry(a[z].rt,y-a[z].l+1);
t2.add(z,i+1,t1.d[y]-t1.d[a[z].rt]+1);
}
t2.dfs(1); t2.divide(1,1); int u,v,w; ll ans;
while (cas--){
x=read(); y=read();
u=getid(x,m+1); v=getid(y,m+1); w=t2.lca(u,v);
x=qry(a[u].rt,x-a[u].l+1); y=qry(a[v].rt,y-a[v].l+1);
if (u==v) printf("%lld\n",t1.dist(x,y)); else{
if (u==w){ swap(u,v); swap(x,y); }
if (v==w){
v=t2.gettp(u,w); ans=t1.d[x]-t1.d[a[u].rt]+t2.d[u]-t2.d[v];
x=a[v].fa; ans+=t1.dist(x,y)+1;
} else{
ans=t1.d[x]-t1.d[a[u].rt]+t1.d[y]-t1.d[a[v].rt]+t2.dist(u,v);
u=t2.gettp(u,w); v=t2.gettp(v,w);
x=a[u].fa; y=a[v].fa;
ans-=(t1.d[t1.lca(x,y)]-t1.d[a[w].rt])<<1;
}
printf("%lld\n",ans);
}
}
return 0;
}
by lych
2016.4.20