题意
给定一棵
n
n
个节点的树。 有 个询问,每次给出两个点集,大小分别为
k1,k2
k
1
,
k
2
,求从两个集合中分别选出一个点,它们的
LCA
L
C
A
的深度最大为多少。
1≤n,m≤105
1
≤
n
,
m
≤
10
5
1≤∑k≤105
1
≤
∑
k
≤
10
5
思路
解法一:dfs序+LCA
dfs d f s 序和 LCA L C A 间,其实不难发现有一个联系,对于一个节点 u u , 序比 u u 小或比 大的点中,与 u u 的 序越近的点,与 u u 的 深度越深。在 dfs d f s 的时候也不难看出这个性质。那么只用对一个点集的 dfs d f s 序排序并维护一个从 dfs d f s 序到编号的映射。每次插入另一个集合的点都与第一个 dfn d f n 比它小的点和 dfn d f n 第一个比它大于等于的点求 LCA L C A ,更新最大深度。
解法二:二分答案
正解好像是这种方法,思路比较清楚。因为越浅的点越容易成为父亲。所以可以二分答案,枚举 LCA L C A 深度(但这里并不是“最近公共祖先”,而是一般的“公共祖先”)。然后把其中一个集合的所有点不到这个深度的都跳到这个深度,并把跳到的点 mark m a r k , check c h e c k 函数则是跳另一个集合中的点至这个深度,若发现跳完后的点已被 mark m a r k ,则返回 yes y e s 。
解法三:树上倍增+dfs序+树状数组
比较神奇的做法,复杂度与二分答案相同,方法也有些类似。先把其中一个集合的点的 dfs d f s 序在树状数组中标记,然后对于另一个集合中的每一个点,倍增的找到第一个能包含集合一种点的深度,并更新答案。不难发现,解法二中判祖孙也可以这种 dfs d f s 序区间包含的方法,但是由于只用判两个点所以直接跳,而这种解法需要判一个点是不是一些点中其中一个的祖先,所以使用树状数组。
解法四:分块决策
这其实不算是一个正常的解法,由于点集大小总和不超过
105
10
5
,数据组数大,点集的大小就小;点集大小大,数据组数就小。那我们可以结合两种暴力:
1.
1.
暴力枚举两个点集的点并求
LCA
L
C
A
,复杂度
O(k1k2)
O
(
k
1
k
2
)
.
2,
2
,
标记一个集合的所有父亲,另一个集合在向父亲走,走到标记过的点更新答案,复杂度
O(n)
O
(
n
)
.
当点集大时,采用暴力二;否则采用暴力一。
代码
解法一:dfs序+LCA
#include<iostream>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define FOR(i,x,y) for(register int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(register int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
void clear(){memset(head,-1,sizeof(head));tot=0;}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u];head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
int fa[N][20],lv[N],L[N],ori[N],bin[N],f[N],n,m,ord;
void dfs(int u,int d)
{
L[u]=++ord,ori[ord]=u;
lv[u]=d;
EOR(i,G,u)
{
int v=G.to[i];
if(v!=fa[u][0])
{
fa[v][0]=u;
dfs(v,d+1);
}
}
}
void make_fa()
{
FOR(j,1,19)
FOR(i,1,n)
if(~fa[i][j-1])
fa[i][j]=fa[fa[i][j-1]][j-1];
}
void jmp(int &x,int stp)
{
while(stp)
{
x=fa[x][bin[lowbit(stp)]];
stp^=lowbit(stp);
}
}
int LCA(int a,int b)
{
if(lv[a]>lv[b])jmp(a,lv[a]-lv[b]);
else if(lv[b]>lv[a])jmp(b,lv[b]-lv[a]);
if(a==b)return a;
DOR(i,19,0)
if(fa[a][i]!=fa[b][i])
{
a=fa[a][i];
b=fa[b][i];
}
return fa[a][0];
}
int main()
{
FOR(i,2,N-3)bin[i]=bin[i>>1]+1;
while(~scanf("%d%d",&n,&m))
{
ord=0;
memset(fa,-1,sizeof(fa));
G.clear();
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v);
G.add(v,u);
}
dfs(1,1);
make_fa();
while(m--)
{
int k1,k2,nod,ans=0;
scanf("%d",&k1);
FOR(i,1,k1)scanf("%d",&f[i]),f[i]=L[f[i]];
sort(f+1,f+1+k1);
scanf("%d",&k2);
FOR(i,1,k2)
{
scanf("%d",&nod);
int t1=lower_bound(f+1,f+1+k1,L[nod])-f,t2=t1-1;
if(t1<=k1)ans=max(ans,lv[LCA(nod,ori[f[t1]])]);
if(t2>=1)ans=max(ans,lv[LCA(nod,ori[f[t2]])]);
}
printf("%d\n",ans);
}
}
return 0;
}
解法二:二分答案
#include<iostream>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define FOR(i,x,y) for(register int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(register int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
void clear(){memset(head,-1,sizeof(head));tot=0;}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u];head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
int fa[N][20],lv[N],bin[N],n,m;
int a[N],b[N],k1,k2,mark[N];
void dfs(int u,int d)
{
lv[u]=d;
EOR(i,G,u)
{
int v=G.to[i];
if(v!=fa[u][0])
{
fa[v][0]=u;
dfs(v,d+1);
}
}
}
void make_fa()
{
FOR(j,1,19)
FOR(i,1,n)
if(~fa[i][j-1])
fa[i][j]=fa[fa[i][j-1]][j-1];
}
int jmp(int x,int stp)
{
while(stp)
{
x=fa[x][bin[lowbit(stp)]];
stp^=lowbit(stp);
}
return x;
}
bool check(int dep,int T)
{
FOR(i,1,k1)if(lv[a[i]]>=dep)mark[jmp(a[i],lv[a[i]]-dep)]=T;
FOR(i,1,k2)if(lv[b[i]]>=dep&&mark[jmp(b[i],lv[b[i]]-dep)]==T)return 1;
return 0;
}
int main()
{
FOR(i,2,N-3)bin[i]=bin[i>>1]+1;
while(~scanf("%d%d",&n,&m))
{
memset(mark,0,sizeof(mark));
memset(fa,-1,sizeof(fa));
G.clear();
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v);G.add(v,u);
}
dfs(1,1);
make_fa();
int T=0;
while(m--)
{
scanf("%d",&k1);
FOR(i,1,k1)scanf("%d",&a[i]);
scanf("%d",&k2);
FOR(i,1,k2)scanf("%d",&b[i]);
int L=1,R=N;
while(L<R)
{
int mid=L+R+1>>1;
if(check(mid,++T))
L=mid;
else R=mid-1;
}
printf("%d\n",L);
}
}
return 0;
}
解法三:树上倍增+dfs序+树状数组
#include<iostream>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define FOR(i,x,y) for(int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
void clear(){memset(head,-1,sizeof(head));tot=0;}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u];head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
struct BinaryIndexedTree
{
int c[N],n;
void reset(int _){memset(c,0,sizeof(c));n=_;}
void update(int k,int val)
{
while(k<=n)
{
c[k]+=val;
k+=lowbit(k);
}
}
int _query(int k)
{
int res=0;
while(k>0)
{
res+=c[k];
k^=lowbit(k);
}
return res;
}
int query(int L,int R){return _query(R)-_query(L-1);}
}BIT;
Linked_list<N,N<<1>G;
int fa[N][20],lv[N],bin[N],n,m,ord,ans;
int a[N],L[N],R[N],k1,k2;
void dfs(int u,int d)
{
L[u]=++ord;
lv[u]=d;
EOR(i,G,u)
{
int v=G.to[i];
if(v!=fa[u][0])
{
fa[v][0]=u;
dfs(v,d+1);
}
}
R[u]=ord;
}
void make_fa()
{
FOR(j,1,19)
FOR(i,1,n)
if(~fa[i][j-1])
fa[i][j]=fa[fa[i][j-1]][j-1];
}
int jmp(int x,int stp)
{
while(stp)
{
x=fa[x][bin[lowbit(stp)]];
stp^=lowbit(stp);
}
return x;
}
int main()
{
FOR(i,2,N-3)bin[i]=bin[i>>1]+1;
while(~scanf("%d%d",&n,&m))
{
memset(fa,-1,sizeof(fa));
BIT.reset(n);
ord=0;
G.clear();
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v);G.add(v,u);
}
dfs(1,1);
make_fa();
while(m--)
{
int ans=0;
scanf("%d",&k1);
FOR(i,1,k1)scanf("%d",&a[i]);
FOR(i,1,k1)BIT.update(L[a[i]],1);
scanf("%d",&k2);
FOR(i,1,k2)
{
int nod;
scanf("%d",&nod);
if(BIT.query(L[nod],R[nod]))
{
ans=max(ans,lv[nod]);
continue;
}
DOR(j,19,0)
if(~fa[nod][j]&&!BIT.query(L[fa[nod][j]],R[fa[nod][j]]))
nod=fa[nod][j];
ans=max(ans,lv[nod]-1);
}
FOR(i,1,k1)BIT.update(L[a[i]],-1);
printf("%d\n",ans);
}
}
return 0;
}
解法四:分块决策
#include<iostream>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define FOR(i,x,y) for(int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
void clear(){memset(head,-1,sizeof(head));tot=0;}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u];head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
int fa[N][20],lv[N],bin[N],n,m,t,ans;
int a[N],b[N],k1,k2,mark[N];
void dfs(int u,int d)
{
lv[u]=d;
EOR(i,G,u)
{
int v=G.to[i];
if(v!=fa[u][0])
{
fa[v][0]=u;
dfs(v,d+1);
}
}
}
void make_fa()
{
FOR(j,1,19)
FOR(i,1,n)
if(~fa[i][j-1])
fa[i][j]=fa[fa[i][j-1]][j-1];
}
int jmp(int x,int stp)
{
while(stp)
{
x=fa[x][bin[lowbit(stp)]];
stp^=lowbit(stp);
}
return x;
}
int LCA(int a,int b)
{
if(lv[a]>lv[b])a=jmp(a,lv[a]-lv[b]);
else if(lv[b]>lv[a])b=jmp(b,lv[b]-lv[a]);
if(a==b)return a;
DOR(i,19,0)
if(fa[a][i]!=fa[b][i])
{
a=fa[a][i];
b=fa[b][i];
}
return fa[a][0];
}
void dfs_a(int u)
{
if(mark[u]==t)return;
mark[u]=t;
dfs_a(fa[u][0]);
}
void dfs_b(int u)
{
if(mark[u]==t-1)ans=max(ans,lv[u]);
if(mark[u]==t)return;
mark[u]=t;
dfs_b(fa[u][0]);
}
void viol()
{
FOR(i,1,k1)FOR(j,1,k2)ans=max(ans,lv[LCA(a[i],b[j])]);
}
void optim()
{
t++;
FOR(i,1,k1)dfs_a(a[i]);
t++;
FOR(i,1,k2)dfs_b(b[i]);
}
int main()
{
FOR(i,2,N-3)bin[i]=bin[i>>1]+1;
while(~scanf("%d%d",&n,&m))
{
memset(mark,0,sizeof(mark));
G.clear();
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v);G.add(v,u);
}
dfs(1,1);
make_fa();
t=0;
while(m--)
{
ans=0;
scanf("%d",&k1);
FOR(i,1,k1)scanf("%d",&a[i]);
scanf("%d",&k2);
FOR(i,1,k2)scanf("%d",&b[i]);
if(20*k1*k2<=n)
viol();
else optim();
printf("%d\n",ans);
}
}
return 0;
}