解决LCA问题主要有在线算法和离线算法。在线算法(问一次答一次):DFS+ST(RMQ)/倍增;离线算法(全部问完了再回答):Tarjan+并查集。先附上讲解博客Orz:(模板是上海大学kuangbin的模板)
在线算法:https://blog.csdn.net/u013076044/article/details/41870751
https://blog.csdn.net/lw277232240/article/details/72870644
模板题:POJ1330
离线算法:https://www.cnblogs.com/JVxie/p/4854719.html
模板题:POJ1470
这道题据说是个板子题。。后来想了下,需要加一个dis数组,用来记录根节点到该结点的路径长度;最后答案=dis[u]+dis[v]-2*dis[LCA(u,v)]。
1.DFS+ST 在线算法:
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<stack>
#include<queue>
using namespace std;
#define ll long long
typedef pair<int,int>pp;
#define mkp make_pair
#define pb push_back
const int INF=0x3f3f3f3f;
const int MAX=40005*2;
int rmq[MAX];
struct ST
{
int dp[MAX][20];
void init(int n)
{
for(int i=1;i<=n;i++)
dp[i][0]=i;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
dp[i][j]=rmq[dp[i][j-1]]<rmq[dp[i+(1<<(j-1))][j-1]]?dp[i][j-1]:dp[i+(1<<(j-1))][j-1];
}
int query(int x,int y)
{
if(x>y)
swap(x,y);
int k=(int)(log(y-x+1.0)/log(2.0));
return rmq[dp[x][k]]<=rmq[dp[y-(1<<k)+1][k]]?dp[x][k]:dp[y-(1<<k)+1][k];
}
};
struct Edge
{
int to,next;
int w;
}edge[MAX];
int tot,head[MAX];
int f[MAX];
int p[MAX];
int cnt;
int dis[MAX];//记录根节点到该节点的路径
ST st;
void init()
{
tot=0;
memset(head,-1,sizeof(head));
memset(dis,0,sizeof(dis));
}
void addedge(int u,int v,int w)
{
edge[tot].to=v;edge[tot].next=head[u];edge[tot].w=w;
head[u]=tot++;
}
void dfs(int u,int pre,int dep)
{
f[++cnt]=u;
rmq[cnt]=dep;
p[u]=cnt;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(v==pre)
continue;
dis[v]=dis[u]+edge[i].w;
dfs(v,u,dep+1);
f[++cnt]=u;
rmq[cnt]=dep;
}
}
void lca_init(int root,int node_num)
{
cnt=0;
dis[root]=0;
dfs(root,root,0);
st.init(2*node_num-1);
}
int lca_query(int u,int v)
{
return f[st.query(p[u],p[v])];
}
bool vis[MAX];
int main()
{
int t;
int n,q;
int u,v,w;
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&q);
init();
memset(vis,false,sizeof(vis));
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);
addedge(v,u,w);
vis[v]=true;
}
int root;
for(int i=1;i<=n;i++)
if(!vis[i])
{
root=i;
break;
}
lca_init(root,n);
while(q--)
{
scanf("%d%d",&u,&v);
int ro=lca_query(u,v);
int ans=dis[u]+dis[v]-2*dis[ro];
printf("%d\n",ans);
}
}
return 0;
}
2.Tarjan+并查集 离线算法
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<stack>
#include<queue>
using namespace std;
#define ll long long
typedef pair<int,int>pp;
#define mkp make_pair
#define pb push_back
const double pi=acos(-1.0);
const double eps=1e-9;
const int INF=0x3f3f3f3f;
const ll MOD=1e9+(ll)7;
const int MAX=40005;
int n;
int dis[MAX];//到树根的距离
int f[MAX];
int Find(int x)
{
if(f[x]==x)
return x;
else
return Find(f[x]);
}
int Union(int u,int v)
{
int u0=Find(u);
int v0=Find(v);
if(u0!=v0)
f[v0]=u0;
}
bool vis[MAX];
int an[MAX];//祖先
struct edge
{
int to,next;
int w;//边的权值
}edge[MAX*3];//注意范围!!
int head[MAX],tot;
void addedge(int u,int v,int w)
{
edge[tot].to=v;edge[tot].next=head[u];edge[tot].w=w;
head[u]=tot++;
}
const int MAXQ=40005;
struct query
{
int q,next;
int index;//查询编号
}que[MAXQ*3];
int ans[MAXQ];//下标0~Q-1
int hea[MAXQ];
int tt;
int Q;//询问次数
void add_query(int u,int v,int index)
{
que[tt].q=v;que[tt].next=hea[u];que[tt].index=index;hea[u]=tt++;
que[tt].q=u;que[tt].next=hea[v];que[tt].index=index;hea[v]=tt++;
}
void init(int n)
{
tot=0;
memset(head,-1,sizeof(head));
tt=0;
memset(hea,-1,sizeof(hea));
memset(vis,false,sizeof(vis));
memset(an,0,sizeof(an));
for(int i=0;i<=n;i++)
f[i]=i;
memset(dis,0,sizeof(dis));
}
void lca(int u)
{
an[u]=u;
vis[u]=true;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(vis[v])
continue;
dis[v]=dis[u]+edge[i].w;
lca(v);
Union(u,v);
an[Find(u)]=u;
}
for(int i=hea[u];i!=-1;i=que[i].next)
{
int v=que[i].q;
if(vis[v])
{
ans[que[i].index]=an[Find(v)];
}
}
}
bool flag[MAX];
int tmp[MAXQ][2];//记录询问的顶点
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&Q);
init(n);
int u,v,w;
memset(flag,false,sizeof(flag));
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);
addedge(v,u,w);
flag[v]=true;
}
for(int i=0;i<Q;i++)
{
scanf("%d%d",&u,&v);
add_query(u,v,i);
tmp[i][0]=u,tmp[i][1]=v;
}
int root;
for(int i=1;i<=n;i++)
if(!flag[i])
{
root=i;
break;
}
dis[root]=0;
lca(root);
for(int i=0;i<Q;i++)
{
//cout<<"i="<<i<<" ans="<<ans[i]<<endl;
int aans=dis[tmp[i][0]]+dis[tmp[i][1]]-2*dis[ans[i]];
printf("%d\n",aans);
}
}
return 0;
}