Problem
给出一个有 n n n 个点的有根树,根节点为 1 1 1,边的长度为 1 1 1。
有 q q q 个询问,询问给定两个整数 p p p 和 k k k,问有多少个有序三元组 ( a , b , c ) (a,b,c) (a,b,c) 满足:
- a a a、 b b b 和 c c c 为三个不同的点,且 a a a 为 p p p 号节点;
- a a a 和 b b b 都是 c c c 的祖先;
- a a a 和 b b b 的距离不超过 k k k。
数据范围: 1 ≤ p , k ≤ n ≤ 3 × 1 0 5 1\le p,k\le n\le3\times 10^5 1≤p,k≤n≤3×105, q ≤ 3 × 1 0 5 q\le3\times 10^5 q≤3×105。
Solution
由于 a a a 和 b b b 的父子关系不清楚,我们分 2 2 2 种情况讨论:
- b b b 是 a a a 的祖先,这一部分的贡献是 ( S i z e [ a ] − 1 ) ∗ m i n ( d e p [ a ] , k ) (Size[a]-1)∗min(dep[a],k) (Size[a]−1)∗min(dep[a],k),减一是把 a a a 自己减掉。
- b b b 是 a a a 的子孙,这一部分的贡献是 a a a 子树中深度在 k k k 以内的点的 ( S i z e − 1 ) (Size-1) (Size−1) 之和,减一的道理一样。
自己画画图算一算,上面两个还是很好理解的。
定义 f [ x ] [ i ] f[x][i] f[x][i] 表示在 x x x 子树中,深度大于等于 j j j 的所有点的 S i z e − 1 Size-1 Size−1 之和(就相当于一个后缀和)。
f f f 的第二维只和深度有关系,用长链剖分可以优化到 O ( n ) O(n) O(n)。
Code
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define N 300005
#define ll long long
using namespace std;
int n,q,p,k,t;
int first[N],v[N<<1],nxt[N<<1];
int fa[N],dep[N],son[N],len[N],Size[N];
ll temp[N],*f[N],*pos=temp,ans[N];
vector<pair<int,int> >Q[N];
void add(int x,int y){
nxt[++t]=first[x],first[x]=t,v[t]=y;
}
void dfs(int x,int fa){
Size[x]=1;
for(int i=first[x];i;i=nxt[i]){
int to=v[i];
if(to==fa) continue;
dep[to]=dep[x]+1;
dfs(to,x),Size[x]+=Size[to];
if(len[to]>len[son[x]]) son[x]=to;
}
len[x]=len[son[x]]+1;
}
void dp(int x,int fa){
if(son[x]) f[son[x]]=f[x]+1,dp(son[x],x);
for(int i=first[x];i;i=nxt[i]){
int to=v[i];
if(to==fa||to==son[x]) continue;
f[to]=pos,pos+=len[to],dp(to,x);
for(int j=0;j<len[to];++j) f[x][j+1]+=f[to][j];
}
for(int i=0;i<Q[x].size();++i){
int k=Q[x][i].first,id=Q[x][i].second;
ans[id]+=1ll*(Size[x]-1)*min(dep[x]-1,k);
if(k>=len[x]-1) ans[id]+=f[x][1];
else ans[id]+=f[x][1]-f[x][k+1];
}
f[x][0]=f[x][1]+Size[x]-1;
}
int main(){
int x,y,i;
scanf("%d%d",&n,&q);
for(i=1;i<n;++i){
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
for(i=1;i<=q;++i){
scanf("%d%d",&x,&y);
Q[x].push_back(make_pair(y,i));
}
dep[1]=1,dfs(1,0);
f[1]=pos,pos+=len[1];
dp(1,0);
for(i=1;i<=q;++i) printf("%lld\n",ans[i]);
return 0;
}