/*
 * 思路：
 * 按之前的思路，每次查询都重新搜索一遍，结果 TLE 了。之后查了一下离线 Tarjan 算法：
 * 先把所有查询都读入，然后在一次 DFS 中统一处理，这样求 LCA 的部分就达到了近似
 * O(n+m) 的复杂度（不计用 map 把目录名映射为编号所带来的 log 因子）。
 */
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<map>
#include<string>
#include<utility>
#include<vector>
using namespace std;
// Capacity shared by the node-indexed arrays and by both edge pools.
// add()/addq() store every tree edge and every query twice (one entry per
// direction), so the pools must hold 2*(n-1) and 2*m entries respectively.
const int maxn = 200005;

// ---- Per-node state, indexed by the ids handed out by get_num ----
int fa[maxn];     // union-find parent used by find() / Tarjan_Lca()
int ind[maxn];    // number of edge lines that named this node first (0 => root candidate)
int head[maxn];   // first tree-edge index of this node's adjacency list, -1 = empty
int qhead[maxn];  // first query-edge index of this node's query list, -1 = empty
int vis[maxn];    // set to 1 once Tarjan_Lca has entered the node
int ced[maxn];    // depth from the root, filled in by dfs()

// One entry of a singly linked adjacency list; used for tree edges and query edges alike.
struct node
{
    int to;    // other endpoint
    int w;     // weight (tree edges only)
    int next;  // next entry in the same list, -1 terminates
    int lca;   // answer slot (query edges only)
    int num;   // index into ques[] (query edges only)
};

// One query as read from input; lca is filled in by Tarjan_Lca.
struct query
{
    int u, v, lca;
} ques[maxn];

node edge[maxn];   // tree-edge pool
node qedge[maxn];  // query-edge pool

int n, m;              // node count and query count of the current test case
int cnt, cnt1, cnt2;   // last assigned node id; next free slot in edge[]; next free slot in qedge[]
map<string, int> mp;   // directory name -> node id
// Returns the integer id of directory name s. The first time a name is seen
// it is assigned the next id (++cnt, so ids start at 1 after cnt is reset).
// Takes s by const reference (no copy) and performs a single map lookup:
// insert() either adds (s, cnt+1) or reports the existing entry.
int get_num(const string& s)
{
    pair<map<string,int>::iterator,bool> ins = mp.insert(make_pair(s, cnt + 1));
    if (ins.second)
        ++cnt;   // a new name consumed id cnt+1
    return ins.first->second;
}
void add(int u,int v,int w)
{
edge[cnt1].w=w,edge[cnt1].to=v,edge[cnt1].next=head[u];
head[u]=cnt1++;
edge[cnt1].w=w,edge[cnt1].to=u,edge[cnt1].next=head[v];
head[v]=cnt1++;
}
void addq(int u,int v,int w)
{
qedge[cnt2].num=w,qedge[cnt2].to=v,qedge[cnt2].next=qhead[u];
qhead[u]=cnt2++;
qedge[cnt2].num=w,qedge[cnt2].to=u,qedge[cnt2].next=qhead[v];
qhead[v]=cnt2++;
}
/*
 * Fills ced[x] for every node x reachable from u without stepping back
 * through `parent`: u gets w, and each tree edge adds its weight edge[].w
 * (1 for every edge Solve adds). `parent` is the node u was reached from
 * (-1 for the root); the edge back to it is skipped, so the undirected
 * adjacency list is walked as a tree.
 *
 * Uses an explicit stack instead of recursion so a path-shaped tree with
 * ~1e5 nodes cannot overflow the call stack. The visit order differs from
 * a recursive DFS. Each node's value is derived only from the node it was
 * reached from, so for a tree the resulting ced[] is identical.
 * The parameter is named `parent` (not `fa`) to avoid shadowing the global
 * union-find array.
 */
void dfs(int u,int parent,int w)
{
    vector<pair<int,int> > todo;   // (node, node it was reached from)
    ced[u] = w;
    todo.push_back(make_pair(u, parent));
    while (!todo.empty())
    {
        int x = todo.back().first;
        int from = todo.back().second;
        todo.pop_back();
        for (int i = head[x]; i != -1; i = edge[i].next)
        {
            int v = edge[i].to;
            if (v == from) continue;   // don't walk back up the tree
            ced[v] = ced[x] + edge[i].w;
            todo.push_back(make_pair(v, x));
        }
    }
}
// Union-find lookup over fa[] with full path compression: returns the
// representative of x's set and re-points every node on the path directly
// at it. Iterative (two passes) so that a long uncompressed chain, e.g.
// after Tarjan_Lca has merged a path-shaped tree upward, cannot exhaust the
// call stack. The final fa[] state matches the recursive version.
int find(int x)
{
    int root = x;
    while (fa[root] != root)   // pass 1: locate the representative
        root = fa[root];
    while (x != root)          // pass 2: compress the path
    {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
/*
 * Offline Tarjan LCA over the tree reachable from u.
 *
 * Nodes are marked vis[] = 1 when first entered. Once every child of a node
 * x has been finished, each query entry i in x's list whose other endpoint
 * is already visited is answered with find(other endpoint). The answer is
 * stored in qedge[i].lca, its twin qedge[i^1].lca, and
 * ques[qedge[i].num].lca. Then x is linked under its parent in the
 * union-find forest (fa[x] = parent).
 *
 * Uses an explicit stack instead of recursion so a path-shaped tree with
 * ~1e5 nodes cannot overflow the call stack. Each stack entry holds a node
 * and the next adjacency-list edge index still to examine for it. The
 * sequence of visits, query answers and fa[] updates is the same as in the
 * recursive formulation.
 */
void Tarjan_Lca(int u)
{
    vector<pair<int,int> > path;   // (node, next edge index to try) from u down to the current node
    fa[u] = u;
    vis[u] = 1;
    path.push_back(make_pair(u, head[u]));
    while (!path.empty())
    {
        int x = path.back().first;
        int e = path.back().second;
        // Skip edges to already-visited nodes (this includes the parent).
        while (e != -1 && vis[edge[e].to])
            e = edge[e].next;
        if (e != -1)
        {
            int child = edge[e].to;
            // Record where x resumes before descending (push_back may reallocate).
            path.back().second = edge[e].next;
            fa[child] = child;
            vis[child] = 1;
            path.push_back(make_pair(child, head[child]));
            continue;
        }
        // All children of x are finished: answer the queries that touch x.
        for (int i = qhead[x]; i != -1; i = qedge[i].next)
        {
            if (vis[qedge[i].to])
            {
                qedge[i].lca = find(qedge[i].to);
                qedge[i^1].lca = qedge[i].lca;
                ques[qedge[i].num].lca = qedge[i].lca;
            }
        }
        path.pop_back();
        // Merge x's set into its parent's, exactly where the recursive
        // version did it (right after the child call returned).
        if (!path.empty())
            fa[x] = path.back().first;
    }
}
/*
 * Reads one test case from cin (using the globals n and m) and computes
 * every query's LCA offline.
 *
 * - Resets all per-case state, then reads n-1 edge lines "s1 s2". Names are
 *   mapped to ids via get_num. Each line adds an undirected edge of weight 1
 *   and increments ind[] of the first name.
 * - Reads m query lines into ques[0..m-1] and into the query lists (addq).
 * - Takes as root a node id in 1..n whose ind[] stayed 0 (the last such id
 *   if several). Fills depths with dfs, then ques[i].lca with Tarjan_Lca.
 */
void Solve()
{
    for (int i = 0; i <= n; i++)
        fa[i] = i;
    memset(head, -1, sizeof(head));
    memset(qhead, -1, sizeof(qhead));
    memset(vis, 0, sizeof(vis));
    memset(ind, 0, sizeof(ind));
    cnt = cnt1 = cnt2 = 0;
    mp.clear();

    string s1, s2;
    // n-1 tree edges; the first name on each line gets ind[]++, so the
    // root (never listed first) keeps ind == 0.
    for (int i = 1; i < n; i++)
    {
        cin >> s1 >> s2;
        int u = get_num(s1);
        int v = get_num(s2);
        add(u, v, 1);
        ind[u]++;
    }
    // m queries, kept both in ques[] and as twin entries in the query lists.
    for (int i = 0; i < m; i++)
    {
        cin >> s1 >> s2;
        int u = get_num(s1);
        int v = get_num(s2);
        addq(u, v, i);
        ques[i].u = u;
        ques[i].v = v;
    }

    int root = 0;
    for (int i = 1; i <= n; i++)
        if (ind[i] == 0)
            root = i;

    dfs(root, -1, 0);
    Tarjan_Lca(root);
}
/*
 * Reads T test cases. For each one it runs Solve() and prints one line per
 * query (u, v):
 *   0                                   if u == v, otherwise
 *   ced[u] - ced[lca]  (+1 if lca != v)
 * NOTE(review): the +1 presumably models a single move down from the LCA to
 * v, per the problem statement; confirm against it.
 */
int main()
{
    // Only iostreams are used for I/O, so decoupling them from C stdio and
    // untying cin from cout is safe and speeds up reading many names.
    ios::sync_with_stdio(false);
    cin.tie(0);

    int T;
    cin >> T;
    while (T--)
    {
        cin >> n >> m;
        Solve();
        for (int i = 0; i < m; i++)
        {
            int ans;
            if (ques[i].u == ques[i].v)
            {
                ans = 0;
            }
            else
            {
                ans = ced[ques[i].u] - ced[ques[i].lca];
                if (ques[i].lca != ques[i].v)
                    ans += 1;
            }
            // '\n' rather than endl: avoid flushing after every answer.
            cout << ans << '\n';
        }
    }
    return 0;
}