题目传送门
题意:
给你 个节点的森林,边权都是 , 个询问。
询问有多少个节点的距离是 的祖先和第 个节点的距离是 的祖先相同。
数据范围: 。
题解:
问题转化一下,设第 个节点的距离是 的祖先是 ,询问 的距离是 的子孙的个数。
设 的距离是 的子孙的个数是 ,答案是 。
先预处理一下,可以倍增求 。
然后掏出树上启发式合并的板子就好了。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll ;
typedef pair<int , int> pii ;
const int maxn = 1e5 + 5;
int n , m , c[maxn] ;
string s[maxn] ;
vector<int> rt ;
struct Link
{
int num , head[maxn] ;
struct Edge
{
int v , next ;
} edge[maxn << 1] ;
void init()
{
num = 0 ;
memset(head , -1 , sizeof(head)) ;
}
void add_edge(int u , int v)
{
edge[num].v = v ;
edge[num].next = head[u] ;
head[u] = num ++ ;
}
} link ;
struct Dsu
{
int siz[maxn] , son[maxn] ;
int dep[maxn] ;
int flag ;
int ans[maxn] , cnt[maxn] ;
int fa[maxn][22] ;
vector<pii> q[maxn] ;
void init()
{
flag = 0 ;
memset(ans , 0 , sizeof(ans)) ;
memset(cnt , 0 , sizeof(cnt)) ;
memset(fa , 0 , sizeof(fa)) ;
}
void dfs1(int f , int u , int s)
{
siz[u] = 1 ;
son[u] = 0 ;
dep[u] = s ;
for(int i = 1 ; i <= 20 ; i ++)
{
int nxt = fa[u][i - 1] ;
fa[u][i] = fa[nxt][i - 1] ;
}
for(int i = link.head[u] ; i != -1 ; i = link.edge[i].next)
{
int v = link.edge[i].v ;
if(v == f) continue ;
fa[v][0] = u ;
dfs1(u , v , s + 1) ;
siz[u] += siz[v] ;
if(siz[v] > siz[son[u]]) son[u] = v ;
}
}
void add(int f , int u , int x)
{
cnt[dep[u]] += x ;
for(int i = link.head[u] ; i != -1 ; i = link.edge[i].next)
{
int v = link.edge[i].v ;
if(v == f || v == flag) continue ;
add(u , v , x) ;
}
}
void dfs2(int f , int u , int keep)
{
for(int i = link.head[u] ; i != -1 ; i = link.edge[i].next)
{
int v = link.edge[i].v ;
if(v == f || v == son[u]) continue ;
dfs2(u , v , 0) ;
}
if(son[u]) dfs2(u , son[u] , 1) , flag = son[u] ;
add(f , u , 1) ;
for(auto x : q[u])
{
int id = x.first ;
int k = x.second ;
ans[id] = cnt[dep[u] + k] ;
}
if(son[u]) flag = 0 ;
if(!keep) add(f , u , -1) ;
}
void print()
{
for(int i = 1 ; i <= m ; i ++)
cout << max(0 , ans[i] - 1) << ' ' ;
}
} dsu ;
int find(int v , int k)
{
int ans = 0 ;
int nxt = dsu.dep[v] - k ;
if(nxt <= 0) return 0 ;
for(int i = 20 ; i >= 0 ; i --)
{
int u = dsu.fa[v][i] ;
if(u == 0) continue ;
if(dsu.dep[u] < nxt) continue ;
v = u ;
}
return v ;
}
int main()
{
//std::ios::sync_with_stdio(false) ;
scanf("%d" , &n) ;
link.init() ;
for(int i = 1 ; i <= n ; i ++)
{
int u ;
scanf("%d" , &u) ;
if(u == 0) rt.push_back(i) ;
else link.add_edge(u , i) , link.add_edge(i , u) ;
}
dsu.init() ;
for(auto x : rt) dsu.dfs1(0 , x , 1) ;
scanf("%d" , &m) ;
for(int i = 1 ; i <= m ; i ++)
{
int v , k ;
int x ;
scanf("%d%d" , &v , &k) ;
x = find(v , k) ;
if(x == 0) continue ;
dsu.q[x].push_back(make_pair(i , k)) ;
}
for(auto x : rt) dsu.dfs2(0 , x , 0) ;
dsu.print() ;
return 0;
}