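// 找到树的直径,然后分别从直径两个端点建立倍增数组实现查找。
// Find a diameter of the tree, then build binary-lifting (doubling) tables rooted at
// each of the diameter's two endpoints and answer the queries with them.
// Why this works: the vertex farthest from any vertex x is a diameter endpoint, so if some
// vertex lies exactly d edges from x, one lies on the path from x to an endpoint, i.e. it is
// the d-th ancestor of x in the tree rooted at that endpoint.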
#include <bits/stdc++.h>
#define pb push_back
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
typedef long long LL;
using namespace std;
// Capacity for vertex ids (1-based) plus slack.
const int maxn = 20010;
// Adjacency lists of the tree.
std::vector<int> G[maxn];
// Binary-lifting tables indexed [vertex][k][flag], filled by init():
// flag 0 = tree rooted at diameter endpoint l, flag 1 = rooted at r.
// 20 levels suffice because 2^19 exceeds any possible depth (< maxn).
int pre[maxn][20][2];  // 2^k-th ancestor (clamped at the root, which is its own parent)
int dis[maxn][20][2];  // number of edges that jump actually covers
// Visited marks shared by the BFS and DFS passes.
bool vis[maxn];
// l, r: diameter endpoints set by bfs(); n: vertex count of the current test case.
int l,r,n;
// Breadth-first search over G from `src`. Returns the vertex with the greatest
// edge distance from `src`; ties go to the vertex dequeued first. If nothing else
// is reachable (e.g. a single-vertex tree), `src` itself is returned.
// Leaves vis[] marking every vertex reachable from `src`.
static int farthestFrom(int src)
{
    queue<pair<int,int> > Q;   // (vertex, distance from src)
    memset(vis, 0, sizeof(vis));
    vis[src] = 1;
    Q.push(make_pair(src, 0));
    // Start at src so the result is always well defined, even when no vertex
    // is farther than distance 0.
    int best = src, bestDist = 0;
    while (!Q.empty())
    {
        int a = Q.front().first;
        int b = Q.front().second;
        Q.pop();
        if (b > bestDist)
        {
            bestDist = b;
            best = a;
        }
        for (size_t i = 0; i < G[a].size(); i++)
        {
            int v = G[a][i];
            if (!vis[v])
            {
                vis[v] = 1;
                Q.push(make_pair(v, b + 1));
            }
        }
    }
    return best;
}

// Double BFS sweep: the vertex farthest from vertex 1 is one diameter endpoint (l),
// and the vertex farthest from l is the other (r). Assumes G holds a connected tree
// on vertices 1..n.
void bfs()
{
    l = farthestFrom(1);
    r = farthestFrom(l);
}
// Walks the tree from `x` and, for lifting table `flag`:
//   - sets pre[u][0][flag] to u's parent (x itself gets `fa`; callers pass x so
//     the root is its own parent),
//   - sets dis[v][0][flag] = 1 for every non-root vertex v (one edge to its parent),
//   - marks every visited vertex in vis[].
// `de` (depth) is accepted for interface compatibility but is not used.
// Uses an explicit stack instead of recursion: the tree height can reach n
// (a path), which could overflow the call stack.
void dfs(int x,int fa,int de,int flag)
{
    (void)de;
    vector<pair<int,int> > st;   // pending (vertex, parent) pairs
    st.push_back(make_pair(x, fa));
    while (!st.empty())
    {
        int u = st.back().first;
        int p = st.back().second;
        st.pop_back();
        vis[u] = 1;
        pre[u][0][flag] = p;
        for (size_t i = 0; i < G[u].size(); i++)
        {
            int v = G[u][i];
            if (v != p)
            {
                dis[v][0][flag] = 1;
                st.push_back(make_pair(v, u));
            }
        }
    }
}
// Builds both binary-lifting tables: flag 0 rooted at l, flag 1 rooted at r.
// Afterwards, for every vertex v reached from the root and every k in [0, 20):
//   pre[v][k][flag] = the 2^k-th ancestor of v, clamped at the root
//                     (the root is its own parent), and
//   dis[v][k][flag] = min(2^k, depth of v) = the number of edges that jump covers
//                     (the root's level-0 distance stays 0 from the memset).
void init()
{
    memset(pre,-1,sizeof(pre));
    memset(dis,0,sizeof(dis));
    const int roots[2] = { l, r };
    for (int flag = 0; flag < 2; ++flag)
    {
        memset(vis,0,sizeof(vis));
        dfs(roots[flag], roots[flag], 1, flag);
        for (int k = 1; k < 20; ++k)
            for (int i = 1; i <= n; ++i)
            {
                const int mid = pre[i][k-1][flag];   // 2^(k-1)-th ancestor
                if (mid == -1) continue;             // only if i was not reached (disconnected input)
                pre[i][k][flag] = pre[mid][k-1][flag];
                dis[i][k][flag] = dis[i][k-1][flag] + dis[mid][k-1][flag];
            }
    }
}
// Returns the vertex exactly d edges above x in lifting table `flag`, or -1 when
// x is fewer than d edges from that table's root. Requires d >= 1.
static int ancestorAt(int x, int d, int flag)
{
    int y = x, left = d;
    while (true)
    {
        // Smallest level whose jump covers at least the remaining distance.
        int j = -1;
        for (int k = 0; k < 20; k++)
            if (dis[y][k][flag] >= left)
            {
                j = k;
                break;
            }
        if (j == -1) return -1;                       // root is closer than `left`
        if (dis[y][j][flag] == left) return pre[y][j][flag];
        // Overshoot: take the largest jump that stays short and search again.
        // j >= 1 here, because dis[y][0] <= 1 <= left forces equality at level 0.
        left -= dis[y][j-1][flag];
        y = pre[y][j-1][flag];
    }
}

// Answers one query: prints a vertex exactly d edges from x, or 0 if none exists.
// If such a vertex exists, one lies on the path from x to a diameter endpoint, so
// look for the d-th ancestor of x in the tree rooted at l, then in the tree rooted at r.
void solve(int x,int d)
{
    if (d == 0)
    {
        printf("%d\n", x);
        return;
    }
    // A negative distance has no answer; letting it through would index dis[.][-1].
    if (d > 0)
    {
        for (int flag = 0; flag < 2; ++flag)
        {
            const int v = ancestorAt(x, d, flag);
            if (v != -1)
            {
                printf("%d\n", v);
                return;
            }
        }
    }
    printf("0\n");
}
int main()
{
int Q,x,y,val,d;
while(scanf("%d%d",&n,&Q)!=EOF)
{
for(int i=0;i<=n;i++) G[i].clear();
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
G[x].pb(y);
G[y].pb(x);
}
bfs();
init();
while(Q--)
{
scanf("%d%d",&val,&d);
solve(val,d);
}
}
return 0;
}