思路
~~~~~ 先将询问离线,这样就可以在点分治的时候一次性处理。
~~~~~ 设当前遍历到的重心是 p p p ,我们要查询所有 经过 p p p 的路径。那么对于 p p p 的子树内的一个点 x x x,它需要加上的贡献就是距离 p p p 为 k − d e p [ x ] k-dep[x] k−dep[x] 的点的数量, k k k 就是题面中的 k k k,注意实际操作的时候要先把 x x x 的子树去掉。而对于 p p p 本身,它要加上的贡献自然就是距离 p p p 为 k k k 的点的数量。
代码
#include<bits/stdc++.h>
#include<bits/stdc++.h>
#define ll long long
#define fir first
#define sec second
using namespace std;
inline ll read(){
char c=getchar();ll res=0;
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9'){
res=(res<<1)+(res<<3)+(c^48);
c=getchar();
}
return res;
}
inline void write(ll k){
if(k>=10)write(k/10);
putchar((k%10)|0x30);
}
pair<ll,ll>ask[100005];
vector<ll>eg[100005],wit[100005],qwe;
ll dis[100005],dit,dge[100005],ans[100005];
ll n,m,siz[100005],sum,mxs,root,dep[100005];
ll vis[100005],stk[100005],top,init[100005];
void getroot(ll fa,ll p){
siz[p]=1;ll s=0;
for(ll v:eg[p]){
if(v==fa||vis[v])continue;
getroot(p,v);
siz[p]+=siz[v];
s=max(s,siz[v]);
}
s=max(s,sum-s);
if(s<mxs)mxs=s,root=p;
}
void getdep(ll fa,ll p){
dep[p]=dep[fa]+1;
for(ll v:eg[p]){
if(v==fa||vis[v])continue;
getdep(p,v);
}
}
void getdis(ll fa,ll p,ll now,ll opt){
if(opt){
for(ll it:wit[p])
qwe.push_back(it);
}
dis[++dit]=now;
for(ll v:eg[p]){
if(v==fa||vis[v])continue;
getdis(p,v,now+1,opt);
}
}
void getans(ll p){
vis[p]=1;top=0;
dge[0]=1;getdep(0,p);
for(ll v:eg[p]){
if(vis[v])continue;
dit=0;getdis(p,v,1,0);
for(ll i=1;i<=dit;i++){
dge[dis[i]]++;
if(!init[dis[i]]){
init[dis[i]]=1;
stk[++top]=dis[i];
}
}
}
for(ll it:wit[p])
ans[it]+=dge[ask[it].sec];
for(ll v:eg[p]){
if(vis[v])continue;
qwe.clear();
dit=0;getdis(p,v,1,1);
for(ll i=1;i<=dit;i++)
dge[dis[i]]--;
for(ll it:qwe)
if(ask[it].sec-dep[ask[it].fir]+1>=0)
ans[it]+=dge[ask[it].sec-dep[ask[it].fir]+1];
for(ll i=1;i<=dit;i++)
dge[dis[i]]++;
}
for(ll i=1;i<=top;i++)dge[stk[i]]=init[stk[i]]=0;
}
void divide(ll p){
getans(p);
for(ll v:eg[p]){
if(vis[v])continue;
root=v;
sum=mxs=siz[v];
getroot(p,v);
divide(root);
}
}
void work(){
cin>>n>>m;
memset(ans,0,sizeof(ans));
memset(siz,0,sizeof(siz));
memset(vis,0,sizeof(vis));
for(ll i=1;i<=n;i++)
eg[i].clear(),wit[i].clear();
for(ll i=1;i<n;i++){
ll x=read(),y=read();
eg[x].push_back(y);
eg[y].push_back(x);
}
for(ll i=1;i<=m;i++){
ask[i].fir=read(),ask[i].sec=read();
wit[ask[i].fir].push_back(i);
}
divide(1);
for(ll i=1;i<=m;i++)write(ans[i]),putchar('\n');
}
signed main(){
ll T;cin>>T;
while(T--)work();
return 0;
}
~~~~~ 引用请附名 —— —— —— OMG_NOIP