还是普通的LCA但是要求的是两个节点之间的距离,学到了一些
一开始我想用带权并查集进行优化,但是LCA合并的过程晚于离线计算的过程,所以路径长度会有所偏差
所以失败告终
网上查询之后懂得要提前进行一下预处理,在输入完全部的边之后,也就是数形成之后,计算dis——》也就是每个点到树根的长度
之后进行询问查询时:u,v 和 rt 这样uv的距离就是dis[u] + dis[v] - 2 * dis[rt]很好理解
时间复杂度也还可以
#include <iostream>
#include <cstdio>
#include <string.h>
using namespace std;
const int maxn = 5e4 + 50;
const int maxm = 7e4 + 6e3;
int id[maxn],qid[maxn];
int cnt,qcnt;
int pre[maxn],cost[maxn];
int vis[maxn];
struct node{
int to,pre,cost;
}e[maxn * 2];
struct node2{
int to,ads,pre;
}q[maxm * 2];
int ans[maxm];
int Find(int x)
{
//cout<<x<<endl;
if(pre[x] == x)return x;
else
{
//cost[x] += cost[pre[x]];
return pre[x] = Find(pre[x]);
}
}
void join(int a,int b)
{
int fa = Find(a),fb = Find(b);
if(fa != fb)
{
pre[fb] = fa;
}
}
void init(int n)
{
for(int i = 0;i <= n;i++)
{
pre[i] = i;
vis[i] = 0;
cost[i] = 0;
}
memset(id,-1,sizeof(id));
memset(qid,-1,sizeof(qid));
cnt = qcnt = 0;
}
void add(int from,int to,int cost)
{
e[cnt].to = to;
e[cnt].pre = id[from];
e[cnt].cost = cost;
id[from] = cnt++;
}
void qadd(int from,int to,int i)
{
q[qcnt].to = to;
q[qcnt].ads = i;
q[qcnt].pre = qid[from];
qid[from] = qcnt++;
}
void get_cost(int rt,int dis)
{
vis[rt] = 1;
cost[rt] = dis;
for(int i = id[rt];~i;i = e[i].pre)
{
int to = e[i].to;
int cos = e[i].cost;
if(!vis[to])
{
get_cost(to,dis+cos);
}
}
}
void tarjan(int rt)
{
//cout<<rt<<endl;
vis[rt] = -1;
for(int i = id[rt];~i;i = e[i].pre)
{
int to = e[i].to;
int cos = e[i].cost;
if(!vis[to])
{
tarjan(to);
join(rt,to);
}
//cout<<"cs"<<to<<" "<<cost[to]<<endl;
}
vis[rt] = 1;
for(int i = qid[rt];~i;i = q[i].pre)
{
int to = q[i].to;
if(vis[to] == 1)
{
//cout<<"to : "<<to<<endl;
//cout<<"pre[to]: "<<pre[to]<<endl;
//cout<<rt<<" "<<to<<"cost: "<<cost[rt]<<" "<<cost[to]<<endl;
int lca = Find(to);
ans[q[i].ads] = abs(cost[lca] - cost[rt]) + abs(cost[lca] - cost[to]);
}
}
}
int main()
{
int n,m,u,v,w;
scanf("%d",&n);
init(n);
for(int i = 1;i < n;++i)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
get_cost(0,0);
memset(vis,0,sizeof(vis));
scanf("%d",&m);
for(int i = 1;i <= m;++i)
{
scanf("%d%d",&u,&v);
qadd(u,v,i);
qadd(v,u,i);
}
tarjan(0);
for(int i = 1;i <= m;++i)
{
printf("%d\n",ans[i]);
}
return 0;
}
/*
7
0 1 1
0 2 1
0 3 1
1 4 1
2 5 1
2 6 1
40
0 1
0 2
0 3
0 4
0 5
0 6
0 0
1 1
1 2
1 3
1 4
1 5
1 6
1 0
2 0
2 1
2 2
2 3
2 4
2 5
2 6
3 0
3 1
3 2
3 3
3 4
3 5
3 6
4 0
4 1
4 2
4 3
4 4
4 5
4 6
5 0
5 1
5 2
5 3
5 4
*/