题意:给出一棵n个节点的树,m次询问,找出u和v的距离
思路:对每次询问的 u 和 v 求出它们的 LCA;预处理数组 dis[i] 表示节点 i 到根的距离,那么 u 和 v 的距离等于 dis[u] + dis[v] - 2 * dis[lca]
给出在线和离线两种方法
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2586
Tarjan 离线算法(DFS + 并查集):
#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma warning (disable: 4996)
using namespace std;
// Capacity limits. maxn bounds node labels; it also bounds the edge pool e[],
// which only ever holds the n - 1 tree edges (each stored once, see main).
const int maxn = 40005;
// maxm bounds the number of queries per test case (lca[] is indexed by query number).
const int maxm = 205;
// Forward-star adjacency-list node: directed edge u -> v with weight w;
// next is the index of the previous edge added from u (-1 terminates the list).
struct edge
{
int u, v, w, next;
}e[maxn];
// head[x]: index of the most recently added edge out of x (-1 if none);
// cnt: number of edges stored so far.
int head[maxn], cnt;
// Query-list node (type name kept as spelled). main() stores each query (u, v)
// twice, once in u's list and once in v's; num is the query's 0-based index.
struct quiry
{
int u, v, num, next;
}q[maxm << 1];
// head1/cnt1: list heads and entry count for the query lists above.
int head1[maxn], cnt1;
// n: node count, m: query count of the current test case.
int n, m;
// Disjoint-set parent pointers used by Find()/Union()/tarjan().
int p[maxn];
// dis[x]: weighted distance from the root to x, filled in by tarjan().
int dis[maxn];
// in[x]: true if some edge points into x; main() uses the first x with in[x] == false as root.
bool in[maxn];
// vis[x]: true once tarjan() has entered x.
bool vis[maxn];
int root;
// lca[i]: lowest common ancestor computed for query i.
int lca[maxm];
// Returns the representative of x's disjoint set and compresses the path:
// every node visited on the way up ends up pointing directly at the representative.
int Find(int x)
{
    // Pass 1: climb parent pointers until reaching a self-parented node.
    int rep = x;
    while (p[rep] != rep)
        rep = p[rep];

    // Pass 2: re-point each node on the path straight at the representative.
    while (p[x] != rep)
    {
        const int up = p[x];
        p[x] = rep;
        x = up;
    }
    return rep;
}
// Merges the sets containing x and y; x's representative is attached under y's.
// Does nothing if they are already in the same set.
void Union(int x, int y)
{
    const int rx = Find(x);
    const int ry = Find(y);
    if (rx == ry)
        return;
    p[rx] = ry;
}
// Resets all per-test-case state. Must be called after n is read,
// because the DSU is re-initialised only for nodes 1..n.
void init()
{
    // Empty edge and query lists (-1 terminates a list).
    memset(head, -1, sizeof(head));
    memset(head1, -1, sizeof(head1));
    cnt = cnt1 = 0;

    // Traversal / bookkeeping arrays.
    memset(vis, false, sizeof(vis));
    memset(in, false, sizeof(in));
    memset(dis, 0, sizeof(dis));

    // Every node starts as its own singleton set.
    for (int node = 1; node <= n; ++node)
        p[node] = node;
}
void addedge(int u, int v, int w)
{
e[cnt].u = u, e[cnt].v = v, e[cnt].w = w, e[cnt].next = head[u], head[u] = cnt++;
}
void addquery(int u, int v, int num)
{
q[cnt1].u = u, q[cnt1].v = v, q[cnt1].num = num, q[cnt1].next = head1[u], head1[u] = cnt1++;
}
// Offline Tarjan LCA: depth-first traversal starting at u.
// - On entering u, answers every query (u, v) whose other endpoint v has already
//   been visited: lca[num] = Find(v).
// - Fills dis[] for every newly reached child: dis[child] = dis[u] + edge weight.
// Why this works: a child is merged into its parent (p[v] = u) only after its
// whole subtree has been processed. So Find(v) for a visited v is the deepest
// ancestor of u whose traversal has not finished yet, i.e. LCA(u, v).
// For v == u, Find(u) == u.
// NOTE(review): recursion depth equals tree depth (up to n); a path-shaped tree
// may exhaust the default stack — confirm stack limits on the target judge.
void tarjan(int u)
{
vis[u] = true;
int v;
// Answer queries whose other endpoint was reached earlier. main() stores each
// query in both endpoints' lists, so whichever endpoint is visited second records the answer.
for (int i = head1[u]; i != -1; i = q[i].next)
{
v = q[i].v;
if (vis[v])
{
lca[q[i].num] = Find(v);
}
}
// Recurse into unvisited children, then hang each finished child's set under u.
for (int i = head[u]; i != -1; i = e[i].next)
{
v = e[i].v;
if (!vis[v])
{
dis[v] = dis[u] + e[i].w;
tarjan(v);
p[v] = u;
}
}
}
int main()
{
int t;
scanf("%d", &t);
while (t--)
{
scanf("%d%d", &n, &m);
init();
int u, v, w, c;
for (int i = 0; i < n - 1; i++)
{
scanf("%d%d%d", &u, &v, &w);
addedge(u, v, w);
in[v] = true;
}
for (int i = 0; i < m; i++)
{
scanf("%d%d", &u, &v);
addquery(u, v, i);
addquery(v, u, i);
}
for (int i = 1; i <= n; i++)
{
if (!in[i])
{
root = i;
break;
}
}
tarjan(root);
for (int i = 0; i < m * 2; i += 2)
{
u = q[i].u, v = q[i].v, c = q[i].num;
printf("%d\n", dis[u] + dis[v] - 2 * dis[lca[c]]);
}
}
return 0;
}
在线算法(欧拉序 + ST 表 RMQ):
#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma warning (disable: 4996)
using namespace std;
// Capacity limit for node labels and for the edge pool (n - 1 edges, each stored once).
const int maxn = 40005;
// NOTE(review): maxm is not referenced by the visible online-algorithm code.
const int maxm = 205;
// Forward-star adjacency-list node: directed edge u -> v with weight w;
// next links to the previous edge added from u (-1 terminates the list).
struct edge
{
int u, v, w, next;
}e[maxn];
// head[x]: most recent edge out of x (-1 if none); cnt: edges stored so far.
int head[maxn], cnt;
// n: node count, m: query count of the current test case.
int n, m;
// in[x]: true if some edge points into x; build() picks the first x without one as root.
bool in[maxn];
int root;
// Euler tour built by dfs(): vs[k] is the node at tour position k, dep[k] its depth.
int vs[maxn << 1];
int dep[maxn << 1];
// id[x]: tour position of x's first visit.
int id[maxn];
// Sparse table: dp[i][j] is the tour index of minimum depth within [i, i + 2^j).
int dp[maxn << 1][20]; //note: doubled because the Euler tour has 2n - 1 entries
// dis[x]: weighted distance from the root to x, filled by dfs().
int dis[maxn];
// Resets per-test-case graph state: empty adjacency lists, no in-edges, zero distances.
void init()
{
    memset(head, -1, sizeof(head));
    memset(in, false, sizeof(in));
    memset(dis, 0, sizeof(dis));
    cnt = 0;
}
void addedge(int u, int v, int w)
{
e[cnt].u = u, e[cnt].v = v, e[cnt].w = w, e[cnt].next = head[u], head[u] = cnt++;
}
void dfs(int u, int fa, int d, int &k)
{
id[u] = k;
vs[k] = u;
dep[k++] = d;
int v;
for (int i = head[u]; i != -1; i = e[i].next)
{
v = e[i].v;
if (v != fa)
{
dis[v] = dis[u] + e[i].w;
dfs(v, u, d + 1, k);
vs[k] = u;
dep[k++] = d;
}
}
}
// Of two tour positions, returns the one with the smaller depth; ties return x.
int Min(int x, int y)
{
    if (dep[y] < dep[x])
        return y;
    return x;
}
int rmq(int l, int r)
{
int k = 0;
while ((1 << (k + 1)) <= r - l + 1) k++;
return Min(dp[l][k], dp[r - (1 << k) + 1][k]);
}
// Builds the sparse table over tour positions [0, n).
// dp[i][j] is the position of minimum depth in [i, i + 2^j).
// The range starts at 0 because dfs() numbers the Euler tour from 0: the root's
// first visit is position 0. The previous version started at 1 and left
// dp[0][*] unset, relying on zero-initialised globals.
void rmq_init(int n)
{
    for (int i = 0; i < n; i++)
        dp[i][0] = i;
    for (int j = 1; (1 << j) <= n; j++)
    {
        // Combine two adjacent windows of length 2^(j-1).
        for (int i = 0; i + (1 << j) <= n; i++)
            dp[i][j] = Min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
    }
}
void build()
{
for (int i = 1; i <= n; i++)
{
if (!in[i])
{
root = i;
break;
}
}
int k = 0;
dfs(root, -1, 0, k);
rmq_init(n * 2);
}
// Lowest common ancestor of u and v: the shallowest node on the Euler tour
// between their first-visit positions.
int lca(int u, int v)
{
    int l = id[u];
    int r = id[v];
    if (r < l)
    {
        const int tmp = l;
        l = r;
        r = tmp;
    }
    return vs[rmq(l, r)];
}
// Reads t test cases. Each case has n nodes, n - 1 weighted edges and m queries.
// It builds the Euler tour + RMQ once and answers each query online.
// NOTE(review): edges are stored one-way (u -> v) and the root is the first node
// with no incoming edge, so this assumes the input lists each edge parent-first.
int main()
{
    int t;
    if (scanf("%d", &t) != 1)
        return 0;
    while (t--)
    {
        init();
        scanf("%d%d", &n, &m);
        for (int i = 0; i < n - 1; i++)
        {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            addedge(u, v, w);
            in[v] = true;
        }
        build();
        for (int i = 0; i < m; i++)
        {
            int u, v;
            scanf("%d%d", &u, &v);
            const int a = lca(u, v);
            // Widen before adding: dis[u] + dis[v] can exceed INT_MAX even though
            // the final distance fits, and signed int overflow is undefined behavior.
            const long long d = static_cast<long long>(dis[u]) + dis[v] - 2LL * dis[a];
            printf("%lld\n", d);
        }
    }
    return 0;
}