题意:给定一棵树, 然后加一条边, 有若干询问, 问你每一个询问(u,v), 加了这条边后可以从u到v节省多少距离。
思路: 一共三种情况, 1, 原路
2,u - x - y - v
3,u - y - x - v
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100010;
struct node{
int to, w, next;
}edge[MAXN*2];
int tot,head[MAXN];
void init(){
tot = 0;memset(head, -1, sizeof(head));
}
void add_edge(int u, int v, int w){
edge[tot].to = v;
edge[tot].w = w;
edge[tot].next = head[u];
head[u] = tot++;
}
//LCA部分
int rmq[2*MAXN];//rmq数组,就是欧拉序列对应的深度序列
struct ST
{
int mm[2*MAXN];
int dp[2*MAXN][20];//最小值对应的下标
void init(int n)
{
mm[0] = -1;
for(int i = 1;i <= n;i++)
{
mm[i] = ((i&(i-1)) == 0)?mm[i-1]+1:mm[i-1];
dp[i][0] = i;
}
for(int j = 1; j <= mm[n];j++)
for(int i = 1; i + (1<<j) - 1 <= n; i++)
dp[i][j] = rmq[dp[i][j-1]] < rmq[dp[i+(1<<(j-1))][j-1]]?dp[i][j-1]:dp[i+(1<<(j-1))][j-1];
}
int query(int a,int b)//查询[a,b]之间最小值的下标
{
if(a > b)swap(a,b);
int k = mm[b-a+1];
return rmq[dp[a][k]] <= rmq[dp[b-(1<<k)+1][k]]?dp[a][k]:dp[b-(1<<k)+1][k];
}
};
int F[MAXN*2];//欧拉序列,就是dfs遍历的顺序,长度为2*n-1,下标从1开始
int P[MAXN];//P[i]表示点i在F中第一次出现的位置
int cnt;
int dis[MAXN];
ST st;
void dfs(int u,int pre,int dep, int d)
{
dis[u] = d;
F[++cnt] = u;
rmq[cnt] = dep;
P[u] = cnt;
for(int i = head[u];i != -1;i = edge[i].next)
{
int v = edge[i].to;
if(v == pre)continue;
dfs(v,u,dep+1, edge[i].w+d);
F[++cnt] = u;
rmq[cnt] = dep;
}
}
void LCA_init(int root,int node_num)//查询LCA前的初始化
{
cnt = 0;
dfs(root,root,0, 0);
st.init(2*node_num-1);
}
int query_lca(int u,int v)//查询u,v的lca编号
{
return F[st.query(P[u],P[v])];
}
int calc(int u, int v){
int LCA = query_lca(u, v);
return dis[u] + dis[v] - 2*dis[LCA];
}
int main(){
int T;
cin>>T;
int icase = 0;
while(T--){
init();
printf("Case #%d:\n", ++icase);
int n,q;
scanf("%d %d", &n, &q);
for(int i=1; i<n; i++){
int u,v,w;
scanf("%d %d %d", &u, &v, &w);
add_edge(u, v, w), add_edge(v, u, w);
}
LCA_init(1, n);
int u,v,w;
int lca = query_lca(u, v);
scanf("%d %d %d", &u, &v, &w);
while(q--){
int x,y;
scanf("%d %d", &x, &y);
int sum1 = calc(x, y);
int sum2 = calc(x, u) + w + calc(y, v);
int sum3 = calc(y, u) + w + calc(x, v);
if(sum2 >= sum1 && sum3>=sum1) printf("0\n");
else printf("%d\n", sum1 - min(sum2, sum3));
}
}
return 0;
}