学习了lca的ST算法。时间复杂度为O(mlogn)
参考博客:
http://blog.csdn.net/y990041769/article/details/40887469
http://blog.csdn.net/liangzhaoyang1/article/details/52549822
树上最短路公式:
记dis[u]为根节点到u节点的距离。
dist(u,v) = dis[u] + dis[v] - 2 * dis[lca(v, v)]
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 40010;
const int maxe = 20;
int n,m;
vector<int>G[maxn];
vector<int>W[maxn];
int tot;
int nod[maxn<<1];
int dep[maxn<<1];
int fir[maxn];
ll dfn[maxn];
int dp[maxn<<2][maxe];
int pp[maxn<<2][maxe];
void dfs(int u,int f,int d)
{
nod[++tot]=u;
dep[tot]=d;
fir[u]=tot;
for(int i=0;i<(int)G[u].size();i++)
{
int v = G[u][i];
if(v==f) continue;
dfn[v]=dfn[u]+W[u][i];
dfs(v,u,d+1);
nod[++tot]=u;
dep[tot]=d;
}
}
void ST()
{
for(int i=1;i<=tot;i++)
{
dp[i][0]=dep[i];
pp[i][0]=i;
}
for(int j=1;j<maxe;j++)
for(int i=1;i+(1<<j)-1<=tot;i++)
{
if(dp[i][j-1]<dp[i+(1<<(j-1))][j-1])
{
dp[i][j]=dp[i][j-1];
pp[i][j]=pp[i][j-1];
}
else
{
dp[i][j]=dp[i+(1<<(j-1))][j-1];
pp[i][j]=pp[i+(1<<(j-1))][j-1];
}
}
}
int lca(int u,int v)
{
int l = fir[u];
int r = fir[v];
if(l>r) swap(l,r);
int k = 0;
while(1<<(k+1)<=r-l+1) k++;
if(dp[l][k]<dp[r-(1<<k)+1][k]) return nod[pp[l][k]];
else return nod[pp[r-(1<<k)+1][k]];
}
void solve()
{
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++)
{
G[i].clear();
W[i].clear();
}
for(int i=1;i<n;i++)
{
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
G[u].push_back(v);
W[u].push_back(w);
G[v].push_back(u);
W[v].push_back(w);
}
dfs(1,0,1);
ST();
for(int i=1;i<=m;i++)
{
int u,v;
scanf("%d %d",&u,&v);
printf("%lld\n",dfn[u]+dfn[v]-2*dfn[lca(u,v)]);
}
}
int main()
{
int T;
scanf("%d",&T);
while(T--) solve();
return 0;
}