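/*
 * 题目 (Problem): http://poj.org/problem?id=1986
 * 题意 (Statement): N个节点M条边的一棵树,M条边的信息,K个询问,求出两个端点之间的距离。
 *   A tree with N nodes and M edges is given (one line per edge), followed by K
 *   queries; for each query, print the distance between the two given nodes.
 * 思路 (Approach): LCA的在线算法RMQ。
 *   Online LCA via Euler tour + RMQ (sparse table);
 *   dist(u, v) = dis[u] + dis[v] - 2 * dis[lca(u, v)].
 * 该算法详细解析见 (Detailed explanation): http://blog.csdn.net/y990041769/article/details/40887469
 * AC.
 */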
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 40005;   // capacity bound for per-vertex arrays

// Euler-tour depth array (1-based): rmq[i] is the depth recorded at tour
// position i. The sparse table below stores positions and compares them by rmq[].
int rmq[2*maxn];

/*
 * Sparse table for range-minimum queries over rmq[1..n].
 * init(n) preprocesses in O(n log n); query(a, b) returns, in O(1), a
 * position p in [min(a,b), max(a,b)] with minimal rmq[p].
 */
struct ST {
    int mm[2*maxn];       // mm[len] = floor(log2(len)), with mm[0] = -1
    int dp[2*maxn][30];   // dp[i][j] = argmin position of rmq over [i, i + 2^j - 1]

    void init(int n) {
        mm[0] = -1;
        for (int len = 1; len <= n; ++len) {
            const bool isPow2 = (len & (len - 1)) == 0;
            mm[len] = mm[len - 1] + (isPow2 ? 1 : 0);
            dp[len][0] = len;   // a window of length 1 is its own minimum
        }
        for (int lvl = 1; lvl <= mm[n]; ++lvl) {
            const int half = 1 << (lvl - 1);
            for (int i = 1; i + 2 * half - 1 <= n; ++i) {
                const int lhs = dp[i][lvl - 1];
                const int rhs = dp[i + half][lvl - 1];
                // On equal depths the right half's position is kept.
                dp[i][lvl] = (rmq[rhs] <= rmq[lhs]) ? rhs : lhs;
            }
        }
    }

    int query(int a, int b) {
        const int lo = std::min(a, b);
        const int hi = std::max(a, b);
        const int k = mm[hi - lo + 1];
        const int left = dp[lo][k];
        const int right = dp[hi - (1 << k) + 1][k];
        // On equal depths the left window's position wins.
        if (rmq[right] < rmq[left]) return right;
        return left;
    }
};
/* One directed half of an undirected edge, stored in a forward-star list. */
struct Edge {
    int to;     // vertex this half-edge leads to
    int next;   // index of the next half-edge from the same source; -1 ends the list
    int w;      // edge weight
};
Edge edge[maxn*2];   // half-edge pool: two entries per undirected input edge
int tot;             // number of half-edges stored in edge[] so far
int head[maxn];      // head[u] = index of u's first half-edge, or -1 if none
int F[maxn*2];       // F[i] = vertex at Euler-tour position i (1-based)
int P[maxn];         // P[u] = first Euler-tour position of u
int dis[maxn];       // accumulated weighted depth: dis[child] = dis[parent] + w
bool flag[maxn];     // flag[v] = v appeared as the second endpoint of some input edge
int cnt;             // number of Euler-tour positions written so far
ST st;               // sparse table over the tour depths
void init()
{
tot = 0;
memset(flag, 0,sizeof(flag));
memset(head, -1, sizeof(head));
}
void addedge(int u, int v, int w)
{
edge[tot].w = w;
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
void dfs(int u, int pre, int dep)
{
F[++cnt] = u;
rmq[cnt] = dep;
P[u] = cnt;
for(int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if(v == pre) continue;
dis[v] = dis[u] + edge[i].w;
dfs(v, u, dep+1);
F[++cnt] = u;
rmq[cnt] = dep;
}
}
/*
 * Prepares LCA queries for the tree rooted at `root`:
 * resets the tour counter, runs dfs to fill F / rmq / P / dis, and builds the
 * sparse table over the tour.
 *
 * root:     vertex to start the traversal from; dis[root] is set to 0.
 * node_num: number of vertices. It is kept for interface compatibility; a
 *           connected tree yields exactly 2*node_num - 1 tour entries.
 */
void LCA_init(int root, int node_num)
{
    cnt = 0;
    // Anchor distances at the actual root. Previously dis[1] was zeroed,
    // which left a stale dis[root] from an earlier test case whenever root != 1.
    dis[root] = 0;
    dfs(root, root, 0);
    // Build over exactly the entries written by this dfs. This equals
    // 2*node_num - 1 for a connected tree and never reads stale tour slots.
    st.init(cnt);
}
/*
 * Returns the lowest common ancestor of u and v: the vertex recorded at the
 * minimum-depth tour position between their first occurrences P[u] and P[v].
 */
int query_lac(int u, int v)
{
    const int pos = st.query(P[u], P[v]);
    return F[pos];
}
/*
 * Reads test cases until EOF or a malformed header line:
 *   "V N", then N lines "u v w c", then K, then K query lines "u v".
 * For each query, prints the tree distance dis[u] + dis[v] - 2*dis[lca(u, v)].
 */
int main()
{
    //freopen("in", "r", stdin);
    int V, N;
    while (scanf("%d %d", &V, &N) == 2) {
        init();
        for (int i = 0; i < N; ++i) {
            int u, v, w;
            char r;   // trailing per-edge character (presumably a direction); unused
            if (scanf("%d%d%d %c", &u, &v, &w, &r) != 4) return 0;
            addedge(u, v, w);
            addedge(v, u, w);
            flag[v] = true;
        }
        // Edges are stored in both directions, so any vertex can serve as the root.
        // Prefer one that never appeared as a second endpoint, and fall back to 1
        // so that root is never left uninitialized.
        int root = 1;
        for (int i = 1; i <= V; ++i) {
            if (!flag[i]) {
                root = i;
                break;
            }
        }
        LCA_init(root, V);
        int K;
        if (scanf("%d", &K) != 1) return 0;
        while (K--) {
            int u, v;
            if (scanf("%d %d", &u, &v) != 2) return 0;
            const int c = query_lac(u, v);
            printf("%d\n", dis[u] + dis[v] - 2 * dis[c]);
        }
    }
    return 0;
}