tanjar算法离线求LCA的思想主要是利用并查集的思想。
求距离的话就是d[start[i]]+end[en[i]]-2*d[lca[i]];
首先从根节点dfs,在深度遍历的回溯的过程中不断的更新自己的父节点,使得查询两个点肯定是存在一颗最近的子树里边,所以
find(v)就是LCA。
void dfs(int u)
{
f[u]=u;//首先把该点的父亲赋值为自己,即记录父亲又当vis数组
for(int i=head[u];~i;i=edge[i].p)
{
int v=edge[i].to;
if(!f[v]) ///f[]表示父亲,一开始赋为0,只要遍历过就肯定不为0 (init里)
d[v]=d[u]+edge[i].val;//记录到根节点的距离
dfs(v);///如果没遍历过就继续跑
f[v]=u;///重点,思想就是上边一步遍历v时f[v]还是自己,这个遍历完了之后才把父亲u赋值给v,
/// 当下边存在x一个询问到v,v和x一直都会是在一棵树下,并且find(v)的值就是最v,x的
///最近的父节点,也就是lca
}
for(int i=qhead[u];~i;i=qedge[i].p)//询问是否有边一样
{
int v=qedge[i].to;
if(f[v]) lca[i/2+1]=find(v);///因为加入边的时候加入了两次,并且从0开始加,所以是i/2+1。
}
}
AC代码,是要比倍增快一点。
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#include<math.h>
#include<set>
#include<stack>
#include<vector>
#include<map>
#include<queue>
#define myself i,l,r
#define lson i<<1
#define rson i<<1|1
#define Lson i<<1,l,mid
#define Rson i<<1|1,mid+1,r
#define half (l+r)/2
#define inff 0x3f3f3f3f
#define lowbit(x) x&(-x)
#define PI 3.14159265358979323846
#define me(a,b) memset(a,b,sizeof(a))
#define min4(a,b,c,d) min(min(a,b),min(c,d))
#define min3(x,y,z) min(min(x,y),min(y,z))
const int dir[4][2]= {0,-1,-1,0,0,1,1,0};
typedef long long ll;
const ll inFF=9223372036854775807;
typedef unsigned long long ull;
using namespace std;
const int maxn=4e5+6;
int head[maxn],qhead[maxn],f[maxn],lca[maxn],d[maxn],st[maxn],en[maxn];
int sign,qsign,n,m;
struct node
{
int to,p,val;
}edge[maxn],qedge[maxn];
int find(int x)
{
return x==f[x]?x:f[x]=find(f[x]);
}
void init()
{
sign=qsign=d[1]=0;
for(int i=0;i<=2*n;i++)
{
head[i]=qhead[i]=-1;
f[i]=0;
}
}
void add(int u,int v,int val)
{
edge[sign]=node{v,head[u],val};
head[u]=sign++;
}
void qadd(int u,int v)
{
qedge[qsign]=node{v,qhead[u]};
qhead[u]=qsign++;
}
void dfs(int u)
{
f[u]=u;
for(int i=head[u];~i;i=edge[i].p)
{
int v=edge[i].to;
if(!f[v]) d[v]=d[u]+edge[i].val,dfs(v),f[v]=u;
}
for(int i=qhead[u];~i;i=qedge[i].p)
{
int v=qedge[i].to;
if(f[v]) lca[i/2+1]=find(v);
}
}
int main()
{
int t,x,y,z;
cin>>t;
while(t--)
{
cin>>n>>m;
init();
for(int i=1;i<n;i++) scanf("%d %d %d",&x,&y,&z),add(x,y,z),add(y,x,z);
for(int i=1;i<=m;i++) scanf("%d %d",&st[i],&en[i]),qadd(st[i],en[i]),qadd(en[i],st[i]);
dfs(1);
for(int i=1;i<=m;i++) printf("%d\n",d[st[i]]+d[en[i]]-2*d[lca[i]]);
}
return 0;
}