题意
给定两棵有n个节点的树,每个节点上有一对数(x,y),表示图G中的一条边。对于每一个x,求出两棵树中x到根路径上所有边在图G中构成的子图的连通块个数。
n≤10000
n
≤
10000
分析
考虑在第一棵树中提取关键点,使得每个点到他最近的关键祖先的距离不超过
n−−√
n
,显然这样的点数不超过
O(n−−√)
O
(
n
)
。
考虑把每个点的询问都放到离他最近的关键祖先中,然后在第一棵树中dfs,同时维护一个可以合并的并查集。
在到达一个关键点时处理属于这个关键点的询问。
如何处理呢?我们可以在第二棵树上dfs一次,同时继续维护并查集,如果到达了一个询问点,则把该询问点在第一棵树中到其关键祖先的路径都遍历一遍,然后就可以得到答案了。
时间复杂度是
O(nn−−√logn)
O
(
n
n
l
o
g
n
)
。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
const int N=10005;
int n,m,cnt,top,ls1[N],ls2[N],bel[N],B,stack[N],f[N],s[N],fa[N],dis[N],now,ans[N],sum;
bool cho[N];
struct edge{int to,next;}e[N*4];
struct data{int x,y;}g1[N],g2[N];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int find(int x)
{
if (f[x]==x) return x;
else return find(f[x]);
}
void add1(int u,int v)
{
e[++cnt].to=v;e[cnt].next=ls1[u];ls1[u]=cnt;
e[++cnt].to=u;e[cnt].next=ls1[v];ls1[v]=cnt;
}
void add2(int u,int v)
{
e[++cnt].to=v;e[cnt].next=ls2[u];ls2[u]=cnt;
e[++cnt].to=u;e[cnt].next=ls2[v];ls2[v]=cnt;
}
void pre(int x)
{
dis[x]=cho[x]=0;
for (int i=ls1[x];i;i=e[i].next)
{
if (e[i].to==fa[x]) continue;
fa[e[i].to]=x;
pre(e[i].to);
if (!cho[e[i].to]) dis[x]=std::max(dis[x],dis[e[i].to]+1);
}
if (dis[x]==B) cho[x]=1;
}
void merge(int x,int y)
{
x=find(x);y=find(y);
if (x==y) return;
if (s[x]>s[y]) std::swap(x,y);
sum--;f[x]=y;s[y]+=s[x];stack[++top]=x;
}
void ret(int tmp)
{
while (top>tmp)
{
int x=stack[top];top--;
s[f[x]]-=s[x];f[x]=x;sum++;
}
}
void dfs2(int x,int fro)
{
int tmp=top;
merge(g2[x].x,g2[x].y);
for (int i=ls2[x];i;i=e[i].next) if (e[i].to!=fro) dfs2(e[i].to,x);
if (bel[x]!=now) {ret(tmp);return;}
for (int y=x;y!=now;y=fa[y]) merge(g1[y].x,g1[y].y);
ans[x]=sum;
ret(tmp);
}
void dfs1(int x)
{
int tmp=top;
merge(g1[x].x,g1[x].y);
for (int i=ls1[x];i;i=e[i].next) if (e[i].to!=fa[x]) dfs1(e[i].to);
if (!cho[x]) {ret(tmp);return;}
now=x;
dfs2(1,0);
ret(tmp);
}
void clear()
{
cnt=0;
for (int i=1;i<=n;i++) ls1[i]=ls2[i]=0;
}
int main()
{
int T=read();
while (T--)
{
n=read();m=read();B=sqrt(n);sum=m;
clear();
for (int i=1;i<=n;i++) g1[i].x=read(),g1[i].y=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
add1(x,y);
}
for (int i=1;i<=n;i++) g2[i].x=read(),g2[i].y=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
add2(x,y);
}
pre(1);cho[1]=1;
for (int i=1,j;i<=n;i++)
{
for (j=i;!cho[j];j=fa[j]);
bel[i]=j;
}
for (int i=1;i<=m;i++) f[i]=i,s[i]=1;
dfs1(1);
for (int i=1;i<=n;i++) printf("%d\n",ans[i]);
}
return 0;
}