You are given a tree (an undirected acyclic connected graph) with N nodes, and edges numbered 1, 2, 3...N-1. Each edge has an integer value assigned to it, representing its length.
We will ask you to perform instructions of the following forms:
DIST a b : ask for the distance between node a and node b
or
KTH a b k : ask for the k-th node on the path from node a to node b
Example:
N = 6
1 2 1 // the edge connecting node 1 and node 2 has cost 1
2 4 1
2 5 2
1 3 1
3 6 2
Path from node 4 to node 6 is 4 -> 2 -> 1 -> 3 -> 6
DIST 4 6 : answer is 5 (1 + 1 + 1 + 2 = 5)
KTH 4 6 4 : answer is 3 (the 4-th node on the path from node 4 to node 6 is 3)
Input
The first line of input contains an integer t, the number of test cases (t <= 25). t test cases follow.
For each test case:
In the first line there is an integer N (N <= 10000)
In the next N-1 lines, the i-th line describes the i-th edge: a line with three integers a b c denotes an edge between a and b with cost c (c <= 100000)
The next lines contain instructions "DIST a b" or "KTH a b k"
The end of each test case is signified by the string "DONE".
There is one blank line between successive tests.
Output
For each "DIST" or "KTH" operation, write one integer representing its result.
Print one blank line after each test.
Example
Input:
1
6
1 2 1
2 4 1
2 5 2
1 3 1
3 6 2
DIST 4 6
KTH 4 6 4
DONE
Output:
5
3
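Summary:
Answer queries on a tree: the distance between two nodes, and the index of the k-th node on the path between them.
Approach: binary lifting + LCA.
The key step is deciding whether the k-th node lies on the u-to-LCA segment or on the LCA-to-v segment.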
#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of the per-node arrays; well above the stated limit N <= 10000.
const long long maxn=100000+100;
// Number of binary-lifting levels filled in by bfs (fa[x][i] for 0 <= i < DEG).
const long long DEG=20;
// Adjacency list kept as array-linked edge records.
// v: target node, nxt: index of the next edge out of the same source (-1 ends the list),
// w: edge length. The array has room for two directed entries per input edge
// (presumably one per direction of an undirected edge).
struct edge
{
long long v,nxt,w;
}edge[maxn*2+100];
// head[u]: index of the first edge leaving u, or -1 if none; cnt: number of edge records stored.
long long head[maxn],cnt=0;
// dep[x]: number of edges on the path from the BFS root to x.
long long dep[maxn];
// Prepend a directed edge u -> v of length w to u's adjacency list.
// Uses the next free slot in `edge` and makes it the new head of u's list.
void add_edge(long long u,long long v,long long w)
{
    const long long slot = cnt++;
    edge[slot].v = v;
    edge[slot].w = w;
    edge[slot].nxt = head[u];
    head[u] = slot;
}
// fa[x][i]: the 2^i-th ancestor of x, filled in by bfs. The root is its own parent,
// so jumps past the root stay at the root. Only the first DEG columns are used.
long long fa[maxn][29];
// deg[x]: weighted distance (sum of edge lengths) from the BFS root to x.
// Note: despite the name, this is not the vertex degree.
long long deg[maxn];
// Breadth-first traversal rooted at u. For every reachable node x it fills:
//   deg[x]   - weighted distance from u,
//   dep[x]   - depth (edge count) from u,
//   fa[x][i] - the 2^i-th ancestor of x, for i < DEG.
// The root is made its own parent, so ancestor jumps saturate at the root.
void bfs(long long u)
{
    std::queue<long long> pending;
    deg[u] = 0;
    dep[u] = 0;
    fa[u][0] = u;
    pending.push(u);
    while (!pending.empty())
    {
        const long long cur = pending.front();
        pending.pop();
        // cur's parent was dequeued before cur, so the parent's table is already complete.
        for (long long lvl = 1; lvl < DEG; ++lvl)
            fa[cur][lvl] = fa[fa[cur][lvl - 1]][lvl - 1];
        for (long long e = head[cur]; e != -1; e = edge[e].nxt)
        {
            const long long child = edge[e].v;
            if (child != fa[cur][0])   // skip the edge leading back to the parent
            {
                deg[child] = deg[cur] + edge[e].w;
                dep[child] = dep[cur] + 1;
                fa[child][0] = cur;
                pending.push(child);
            }
        }
    }
}
// Lowest common ancestor of u and v via binary lifting over fa[][].
// Requires bfs to have filled dep[] and fa[][] for the current tree.
long long lca(long long u,long long v)
{
    // Make v the deeper endpoint.
    if (dep[u] > dep[v])
        std::swap(u, v);
    // Raise v to u's depth by splitting the depth gap into powers of two.
    long long gap = dep[v] - dep[u];
    for (long long bit = 0; gap != 0; ++bit, gap >>= 1)
    {
        if (gap & 1)
            v = fa[v][bit];
    }
    if (u == v)
        return u;   // u is an ancestor of v
    // Jump both nodes while their ancestors differ; they stop just below the LCA.
    for (long long bit = DEG - 1; bit >= 0; --bit)
    {
        if (fa[u][bit] != fa[v][bit])
        {
            u = fa[u][bit];
            v = fa[v][bit];
        }
    }
    return fa[u][0];
}
// Weighted path length between xx and yy: the distance from each endpoint up to their LCA, summed.
long long getdis(long long xx,long long yy)
{
    const long long meet = lca(xx, yy);
    const long long up_from_x = deg[xx] - deg[meet];
    const long long up_from_y = deg[yy] - deg[meet];
    return up_from_x + up_from_y;
}
// Return the k-th ancestor of u: apply the 2^bit jump for every set bit of k.
// Jumps beyond the root stay at the root, because bfs makes the root its own parent.
long long getk(long long u,long long k)
{
    long long bit = 0;
    while (k != 0)
    {
        if (k & 1)
            u = fa[u][bit];
        k >>= 1;
        ++bit;
    }
    return u;
}
// Per-node scratch array. It is cleared at the start of each test case in main but
// is not otherwise read or written by the visible code.
long long flag[maxn];
// Reads t test cases. Each one is N, then N-1 edges "a b c", then "DIST a b" /
// "KTH a b k" queries up to "DONE". Prints one integer per query and one blank
// line after each test case, as the statement requires. All input goes through scanf.
int main ()
{
    long long t;
    if (scanf("%lld", &t) != 1)
        return 0;
    while (t--)
    {
        long long m;
        // Reset all per-test state before reading the new tree.
        memset(head, -1, sizeof(head));
        memset(flag, 0, sizeof(flag));
        memset(deg, 0, sizeof(deg));
        memset(dep, 0, sizeof(dep));
        memset(fa, 0, sizeof(fa));
        cnt = 0;
        if (scanf("%lld", &m) != 1)
            break;
        for (long long i = 1; i < m; i++)
        {
            long long xx, yy, zz;
            scanf("%lld%lld%lld", &xx, &yy, &zz);
            // The tree is undirected: store both directions so the BFS from node 1
            // reaches every node, however the input orients each edge.
            add_edge(xx, yy, zz);
            add_edge(yy, xx, zz);
        }
        bfs(1);
        char op[20];
        // Bounded width so an overlong token cannot overflow op.
        while (scanf("%19s", op) == 1)
        {
            if (op[1] == 'O')          // "DONE" ends this test case
                break;
            if (op[1] == 'T')          // "KTH a b k"
            {
                long long u, v, k;
                scanf("%lld%lld%lld", &u, &v, &k);
                const long long p = lca(u, v);
                const long long x = dep[u] - dep[p];   // edges from u up to the LCA
                if (x + 1 >= k)
                {
                    // The k-th node is on the u -> LCA segment: lift u by k-1.
                    printf("%lld\n", getk(u, k - 1));
                }
                else
                {
                    // The path has dep[u]+dep[v]-2*dep[p]+1 nodes. The k-th from u is
                    // reached by lifting v by (total - k).
                    printf("%lld\n", getk(v, dep[v] + dep[u] - 2 * dep[p] + 1 - k));
                }
            }
            else                        // "DIST a b"
            {
                long long u, v;
                scanf("%lld%lld", &u, &v);
                printf("%lld\n", getdis(u, v));
            }
        }
        // One blank line after each test case.
        printf("\n");
    }
    return 0;
}
/*
1
6
1 2 1
2 4 1
2 5 2
1 3 1
3 6 2
DIST 4 6
KTH 4 6 4
DONE
*/