题目:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=3320
题意:n个点,n-1条边,给出边的端点及权值,Q个询问,三个点的最短距离。
思路:由于n(1~50000),Q(1~700000), 所以采用离线算法Tarjan,O(n+Q),
三个点的最短距离是两点之间的最短距离相加除以2,所以将两个点的最近公共祖先存起来,再进行计算。
接下来便是套模板,询问的端点应该存两次,防止存在到达询问是还有点没有被访问到 的 情况出现。
AC.
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 50005;
const int maxq = 70005;
int far[maxn];
int find(int x)
{
if(far[x] == x) return x;
return far[x] = find(far[x]);
}
void unite(int a, int b)
{
int x = find(a), y = find(b);
far[y] = x;
}
struct Edge {
int to, next, w;
}edge[maxn*2];
int head[maxn], tot;
void addedge(int u, int v, int w)
{
edge[tot].to = v;
edge[tot].w = w;
edge[tot].next = head[u];
head[u] = tot++;
}
short asku[maxq*6], askv[maxq*6];
int askid[maxq*6], asknext[maxq*6];
int h[maxn], tt;
void addask(int u, int v, int index)
{
asku[tt] = u;
askv[tt] = v;
askid[tt] = index;
asknext[tt] = h[u];
h[u] = tt++;
}
bool vis[maxn];
int anc[maxn], dis[maxn];
int res[maxq*6];
void LCA(int u)
{
anc[u] = u;
vis[u] = 1;
for(int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if(!vis[v]) {
dis[v] = dis[u] + edge[i].w;
LCA(v);
unite(u, v);
anc[find(u)] = u;
}
}
for(int i = h[u]; i != -1; i = asknext[i]) {
int v = askv[i];
if(vis[v]) {
res[i] = anc[find(v)];
}
}
}
int n;
void init()
{
tot = 0; tt = 0;
for(int i = 0; i <= n; ++i) {
far[i] = i;
dis[i] = 0;
vis[i] = 0;
}
memset(head, -1, sizeof(head));
memset(h, -1, sizeof(h));
memset(res, -1, sizeof(res));
}
int main()
{
//freopen("in", "r", stdin);
int ca = 1;
while(~scanf("%d", &n)) {
int u, v, r, w;
init();
for(int i = 0; i < n-1; ++i) {
scanf("%d %d %d", &u, &v, &w);
addedge(u, v, w);
addedge(v, u, w);
}
int q;
scanf("%d", &q);
for(int i = 0; i < q; ++i) {
scanf("%d %d %d", &u, &v, &r);
addask(u, v, i);
addask(v, u, i);
addask(u, r, i);
addask(r, u, i);
addask(v, r, i);
addask(r, v, i);
}
LCA(0);
if(ca != 1)printf("\n");
int ans = 0, f;
for(int i = 0; i < tt; i+=6) {
ans = 0;
for(int j = i; j < i+6; j+=2) {
u = asku[j]; v = askv[j];
f = res[j];
if(f == -1) f = res[j+1];
//printf("%d (%d %d)\n", f, u, v);
ans += (dis[u] + dis[v] - 2*dis[f]);
}
printf("%d\n", ans/2);
}
ca++;
}
return 0;
}