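// Binary lifting (doubling) method:
// DIST a b: report the distance from a to b
// KTH a b k: report the k-th node on the path from a to b, where a itself is the first node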
// Capacity of the vertex-indexed arrays (10000 plus some slack) and a large constant.
const int N = 10000 + 10;
const int INF = 0x3f3f3f3f;

// One directed arc of the linked adjacency lists.
struct edge
{
    int to;   // vertex the arc points to
    int cost; // weight of the arc
    int next; // index in g[] of the next arc leaving the same vertex, -1 ends the list
};
struct edge g[N * 2]; // arc pool; add_edge() is called twice per undirected edge

int cnt;     // number of arcs currently stored in g[]
int head[N]; // head[v]: index of the most recently added arc leaving v, -1 if none

int dep[N];     // dep[v]: depth assigned to v by dfs()
int dis[N];     // dis[v]: accumulated edge cost from the dfs() start vertex to v
int fat[N][20]; // fat[v][j]: the 2^j-th ancestor of v; fat[v][0] is the parent
// Reset the adjacency structure before a new test case is read.
void init()
{
    // Filling every byte with 0xFF makes each int equal to -1, the "no arc" marker.
    memset(head, -1, sizeof(head));
    cnt = 0;
}
// Store the directed arc v -> u with weight `cost` and link it at the front
// of v's arc list.
void add_edge(int v, int u, int cost)
{
    const int id = cnt++; // slot in g[] taken by the new arc

    g[id].to = u;
    g[id].cost = cost;
    g[id].next = head[v];
    head[v] = id;
}
// Depth-first traversal that records depth, parent and distance of every
// vertex reachable from v.
//   v   - vertex being visited
//   fa  - vertex the traversal arrived from (stored as v's level-0 ancestor)
//   d   - depth to assign to v
//   val - accumulated edge cost to assign to v
void dfs(int v, int fa, int d, int val)
{
    fat[v][0] = fa;
    dep[v] = d;
    dis[v] = val;

    int e = head[v];
    while (e != -1)
    {
        const int u = g[e].to;
        if (u != fa) // do not walk back along the arc we came in on
        {
            dfs(u, v, d + 1, val + g[e].cost);
        }
        e = g[e].next;
    }
}
// Fill the binary-lifting table: fat[i][j] becomes the 2^j-th ancestor of
// vertex i, for every vertex 1..n, derived from level j-1.
// Must run after dfs() has written fat[i][0] for the current tree.
//
// Every level of the table is rewritten, not only those with 2^j <= n.
// lca() scans levels 18..0 unconditionally and fat[][] is never cleared
// between test cases, so a level left untouched here could still hold
// ancestors computed for an earlier, larger tree.
void lca_init(int n)
{
    const int levels = (int)(sizeof(fat[0]) / sizeof(fat[0][0]));

    for (int j = 1; j < levels; j++)
    {
        for (int i = 1; i <= n; i++)
        {
            // Two jumps of 2^(j-1) make one jump of 2^j.
            fat[i][j] = fat[fat[i][j - 1]][j - 1];
        }
    }
}
// Return the lowest common ancestor of v and u using the fat[][] jump table.
// Requires dfs() and lca_init() to have been run for the current tree.
int lca(int v, int u)
{
    // Make v the deeper (or equally deep) endpoint.
    if (dep[v] < dep[u])
    {
        const int tmp = v;
        v = u;
        u = tmp;
    }

    // Lift v by the depth difference, one binary digit at a time.
    int bit = 0;
    for (int gap = dep[v] - dep[u]; gap != 0; gap >>= 1)
    {
        if (gap & 1)
        {
            v = fat[v][bit];
        }
        ++bit;
    }

    if (v == u)
    {
        return v;
    }

    // Lift both vertices together while their ancestors still differ; they
    // end up as two distinct children of the answer.
    for (int i = 18; i >= 0; i--)
    {
        if (fat[v][i] == fat[u][i])
        {
            continue;
        }
        v = fat[v][i];
        u = fat[u][i];
    }
    return fat[v][0];
}
// Return the k-th vertex on the path from v to u, counting v itself as the
// first vertex.
int query_kth(int v, int u, int k)
{
    const int top = lca(v, u);
    const int upLen = dep[v] - dep[top] + 1; // vertices on the v .. top part, both ends included

    int node;  // vertex to start climbing from
    int steps; // how many parent links to follow from `node`
    if (k <= upLen)
    {
        // Target lies on the v side: climb k - 1 links from v.
        node = v;
        steps = k - 1;
    }
    else
    {
        // Target lies on the u side: measure the climb from u instead.
        node = u;
        steps = dep[u] - dep[top] - (k - upLen);
    }

    for (int bit = 0; (steps >> bit) != 0; bit++)
    {
        if ((steps >> bit) & 1)
        {
            node = fat[node][bit];
        }
    }
    return node;
}
// Driver: for each test case read a weighted tree, preprocess it, then answer
// DIST / KTH commands until a token whose second character is 'O' is read.
int main()
{
    int t, n;
    if (scanf("%d", &t) != 1)
    {
        return 0;
    }
    while (t--)
    {
        init();
        if (scanf("%d", &n) != 1)
        {
            break;
        }

        int a, b, c;
        // A tree on n vertices has n - 1 edges. Reading n of them made the
        // last scanf fail on the first command word and re-inserted the
        // previous (or, for n == 1, an uninitialised) edge.
        for (int i = 1; i < n; i++)
        {
            if (scanf("%d%d%d", &a, &b, &c) != 3)
            {
                return 0;
            }
            add_edge(a, b, c);
            add_edge(b, a, c);
        }

        dfs(1, 0, 1, 0); // root is vertex 1, its parent is the dummy vertex 0
        lca_init(n);

        char opt[20];
        // "%19s" keeps the token inside opt[]; the return-value check ends the
        // loop on end of input instead of spinning on a stale buffer.
        while (scanf("%19s", opt) == 1 && opt[1] != 'O')
        {
            if (scanf("%d%d", &a, &b) != 2)
            {
                break;
            }
            if (opt[0] == 'D')
            {
                int ans = lca(a, b);
                printf("%d\n", dis[a] - 2 * dis[ans] + dis[b]);
            }
            else
            {
                if (scanf("%d", &c) != 1)
                {
                    break;
                }
                printf("%d\n", query_kth(a, b, c));
            }
        }
    }
    return 0;
}
// LCA reduced to RMQ:
// Given an undirected tree, find the distance between any two vertices of the tree
typedef long long ll;

// Capacity of the vertex-indexed arrays.
const int N = 50010;

// One directed arc of the linked adjacency lists.
struct edge
{
    int to;   // vertex the arc points to
    int cost; // weight of the arc
    int next; // index in g[] of the next arc leaving the same vertex, -1 ends the list
};
struct edge g[N * 2]; // arc pool; add_edge() is called twice per undirected edge

int cnt;     // number of arcs currently stored in g[]
int head[N]; // head[v]: index of the most recently added arc leaving v, -1 if none

int dis[N]; // dis[v]: accumulated edge cost from the dfs() start vertex to v

int dp[20][N * 2]; // dp[k][i]: tour position with the smallest depth among positions i .. i + 2^k - 1
int deg[N];        // in-degree of each vertex (only cleared by init() in the code shown here)

int tot;        // number of tour positions written so far
int dep[N * 2]; // dep[i]: depth of the vertex at tour position i
int ord[N * 2]; // ord[i]: vertex at tour position i
int fir[N];     // fir[v]: first tour position at which v appears in ord[]
// Reset all per-test-case state: arc pool, adjacency heads, degree counters
// and the tour position counter.
void init()
{
    tot = 0; // tour positions are numbered from 1 again
    cnt = 0;
    memset(deg, 0, sizeof(deg));
    // Filling every byte with 0xFF makes each int equal to -1, the "no arc" marker.
    memset(head, -1, sizeof(head));
}
void add_edge(int v, int u, int cost)
{
g[cnt].to = u, g[cnt].cost = cost, g[cnt].next = head[v], head[v] = cnt++;
}
// Depth-first traversal that writes the visiting sequence (ord[] / dep[]),
// the first position of every vertex (fir[]) and the distance from the start
// vertex (dis[]). Further per-vertex data, such as the predecessor vertex or
// arc, could be recorded here as well.
//   v    - vertex being visited
//   fa   - vertex the traversal arrived from
//   d    - depth of v
//   cost - accumulated edge cost from the start vertex to v
void dfs(int v, int fa, int d, int cost)
{
    ++tot;
    ord[tot] = v;
    dep[tot] = d;
    fir[v] = tot;
    dis[v] = cost;

    for (int e = head[v]; e != -1; e = g[e].next)
    {
        const int u = g[e].to;
        if (u != fa)
        {
            dfs(u, v, d + 1, dis[v] + g[e].cost);

            // Back at v after finishing the subtree of u: record v again.
            ++tot;
            ord[tot] = v;
            dep[tot] = d;
        }
    }
}
// Build the sparse table over tour positions 1..n.
// dp[k][pos] holds the position with the smaller dep[] value within the
// window of 2^k positions starting at pos (the right half wins a tie).
void ST(int n)
{
    // Windows of length 1 contain only themselves.
    for (int pos = 1; pos <= n; pos++)
    {
        dp[0][pos] = pos;
    }

    for (int k = 1; (1 << k) <= n; k++)
    {
        const int half = 1 << (k - 1);
        const int last = n - (1 << k) + 1; // last start position whose window fits
        for (int pos = 1; pos <= last; pos++)
        {
            const int left = dp[k - 1][pos];
            const int right = dp[k - 1][pos + half];
            if (dep[left] < dep[right])
            {
                dp[k][pos] = left;
            }
            else
            {
                dp[k][pos] = right;
            }
        }
    }
}
// Return the tour position in [l, r] whose dep[] value is smallest, using the
// table built by ST(). Expects l <= r.
//
// The window exponent k = floor(log2(r - l + 1)) is computed with integer
// arithmetic. The earlier floating-point log()/log(2.0) form depended on
// rounding before truncation and needed <cmath>.
int RMQ(int l, int r)
{
    const int len = r - l + 1;

    int k = 0;
    while ((1 << (k + 1)) <= len)
    {
        k++;
    }

    // Two windows of length 2^k, one anchored at each end, cover [l, r].
    const int left = dp[k][l];
    const int right = dp[k][r - (1 << k) + 1];
    return dep[left] < dep[right] ? left : right;
}
// Return the lowest common ancestor of v and u: the vertex at the
// smallest-depth tour position between the first occurrences of v and u.
int LCA(int v, int u)
{
    int lo = fir[v];
    int hi = fir[u];
    if (lo > hi)
    {
        const int tmp = lo;
        lo = hi;
        hi = tmp;
    }
    return ord[RMQ(lo, hi)];
}
// Driver: for each test case read a weighted tree with n vertices and m
// queries, then print the distance between the two vertices of every query.
int main()
{
    int t, n, m, a, b, c;

    scanf("%d", &t);
    while (t--)
    {
        init();
        scanf("%d%d", &n, &m);

        // n - 1 undirected edges, each stored as two arcs.
        for (int remaining = n - 1; remaining > 0; remaining--)
        {
            scanf("%d%d%d", &a, &b, &c);
            add_edge(a, b, c);
            add_edge(b, a, c);
        }

        dfs(1, -1, 1, 0);
        // Note: the visiting sequence has 2 * n - 1 entries, so that is the
        // length handed to ST(), not n.
        ST(2 * n - 1);

        for (int q = 0; q < m; q++)
        {
            scanf("%d%d", &a, &b);
            const int meet = LCA(a, b);
            printf("%d\n", dis[a] + dis[b] - 2 * dis[meet]);
        }
    }
    return 0;
}
// Tarjan's algorithm (offline LCA)
#include <bits/stdc++.h>
using namespace std;
// Capacity of the vertex-indexed arrays (50000 plus some slack).
const int N = 50000 + 10;

// One list entry; the same pool holds tree arcs and query records.
struct edge
{
    int to;   // other endpoint
    int cost; // arc weight (written by add_edge only)
    int id;   // query number (written by qadd_edge only)
    int next; // index in g[] of the next entry of the same list, -1 ends the list
};
struct edge g[N * 5];

int cnt;      // number of entries currently stored in g[]
int head[N];  // head[v]: first tree arc leaving v, -1 if none
int qhead[N]; // qhead[v]: first query record attached to v, -1 if none

int dis[N];  // dis[v]: accumulated edge cost from the traversal start to v
int par[N];  // par[v]: disjoint-set parent of v
int ans[N];  // ans[i]: answer of query number i
bool vis[N]; // vis[v]: true once tarjan_lca() has entered v
// Reset all per-test-case state for a tree with vertices 1..n: empty both
// list families, clear the visited flags and make every vertex its own
// disjoint-set representative.
void init(int n)
{
    cnt = 0;
    memset(vis, 0, sizeof(vis));
    // Filling every byte with 0xFF makes each int equal to -1, the "empty list" marker.
    memset(qhead, -1, sizeof(qhead));
    memset(head, -1, sizeof(head));

    for (int v = n; v >= 1; v--)
    {
        par[v] = v;
    }
}
// Store the directed tree arc v -> u with weight `cost` at the front of v's
// arc list. The `id` field of the entry is left untouched.
void add_edge(int v, int u, int cost)
{
    const int slot = cnt;
    cnt = slot + 1;

    g[slot].to = u;
    g[slot].cost = cost;
    g[slot].next = head[v];
    head[v] = slot;
}
void qadd_edge(int v, int u, int id)
{
g[cnt].to = u, g[cnt].id = id, g[cnt].next = qhead[v], qhead[v] = cnt++;
}
// Return the disjoint-set representative of x, re-pointing the vertices on
// the walked chain directly at that representative.
int Find(int x)
{
    // First pass: locate the representative.
    int root = x;
    while (par[root] != root)
    {
        root = par[root];
    }

    // Second pass: compress the chain from x upwards.
    for (int cur = x; par[cur] != root;)
    {
        const int up = par[cur];
        par[cur] = root;
        cur = up;
    }
    return root;
}
// Offline traversal: visit the tree from v, and for every query attached to
// v whose other endpoint has already been visited, store the distance between
// the two endpoints in ans[].
//   v - vertex being visited
//   d - accumulated edge cost from the start vertex to v
void tarjan_lca(int v, int d)
{
    dis[v] = d;
    vis[v] = true;

    for (int e = head[v]; e != -1; e = g[e].next)
    {
        const int u = g[e].to;
        if (vis[u])
        {
            continue;
        }
        tarjan_lca(u, d + g[e].cost);
        par[u] = v; // merge the finished child into v's set
    }

    for (int q = qhead[v]; q != -1; q = g[q].next)
    {
        const int other = g[q].to;
        if (!vis[other])
        {
            continue; // the mirrored record at `other` will handle this query later
        }
        const int meet = Find(other);
        ans[g[q].id] = dis[v] + dis[other] - 2 * dis[meet];
    }
}
// Driver: for each test case read a weighted tree with n vertices and m
// queries, answer all queries in one traversal, then print them in input order.
int main()
{
    int t, n, m;

    scanf("%d", &t);
    while (t--)
    {
        scanf("%d%d", &n, &m);
        init(n);

        int a, b, c;

        // Build the graph: n - 1 undirected edges, two arcs each.
        for (int remaining = n - 1; remaining > 0; remaining--)
        {
            scanf("%d%d%d", &a, &b, &c);
            add_edge(a, b, c);
            add_edge(b, a, c);
        }

        // Record the queries on both endpoints so whichever is visited second
        // can answer.
        for (int id = 1; id <= m; id++)
        {
            scanf("%d%d", &a, &b);
            qadd_edge(a, b, id);
            qadd_edge(b, a, id);
        }

        tarjan_lca(1, 0);

        for (int id = 1; id <= m; id++)
        {
            printf("%d\n", ans[id]);
        }
    }
    return 0;
}