题目大意
有一棵
n
n
n 个节点的有根树,有
m
m
m 组询问
每次询问给出
a
,
k
a,k
a,k,求有多少个三元组
(
a
,
b
,
c
)
(a,b,c)
(a,b,c),满足
a
,
b
a,b
a,b 都是
c
c
c 的祖先,并且
a
,
b
a,b
a,b 之间的距离不超过
k
k
k
刚开始没有思路,看了题解的分类讨论后发现可以很轻松地转化为线段树合并
解题思路
首先,定义
d
x
d_x
dx 为点
x
x
x 的深度(根节点的深度为
1
1
1)
s
x
s_x
sx 为点
x
x
x 的子树大小(包括
x
x
x)
考虑两种情况
- b b b 是 a a a 的祖先:那么 b b b 就必须在 a a a 到根的路径上。可能方案数就是 min ( d a − 1 , k ) × ( s a − 1 ) \min (d_a-1,k) \times (s_a-1) min(da−1,k)×(sa−1)
- b b b 在 a a a 的子树中:方案数就是 ∑ d x ∈ [ d a + 1 , d a + k ] ( s x − 1 ) \sum\limits_{d_x \in [d_a+1,d_a+k]} (s_x-1) dx∈[da+1,da+k]∑(sx−1),注意这里的 x x x 必须在 a a a 的子树中。
我们可以用线段树合并维护子树信息来实现操作2,最终的答案就是两方案数之和
#include<cstdio>
#include<iostream>
using namespace std;
const long long Maxn=300000+10,inf=0x3f3f3f3f;
const long long Maxm=6000000+10;
long long d[Maxn],s[Maxn];
long long nxt[Maxn<<1],to[Maxn<<1];
long long sum[Maxm],ls[Maxm],rs[Maxm];
long long root[Maxn],head[Maxn];
long long n,m,idcnt,edgecnt=1;
inline void add(long long x,long long y)
{
++edgecnt;
nxt[edgecnt]=head[x];
to[edgecnt]=y;
head[x]=edgecnt;
}
inline long long read()
{
long long s=0,w=1;
char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0' && ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
return s*w;
}
inline void push_up(long long x)
{
sum[x]=sum[ls[x]]+sum[rs[x]];
}
void modify(long long &x,long long l,long long r,long long pos,long long v)
{
if(!x)x=++idcnt;
if(l==r)
{
sum[x]+=v;
return;
}
long long mid=(l+r)>>1;
if(pos<=mid)modify(ls[x],l,mid,pos,v);
else modify(rs[x],mid+1,r,pos,v);
push_up(x);
}
long long merge(long long x,long long y)
{
if(!x || !y)return x|y;
long long cur=++idcnt;
sum[cur]=sum[x]+sum[y];
ls[cur]=merge(ls[x],ls[y]);
rs[cur]=merge(rs[x],rs[y]);
return cur;
}
long long query(long long k,long long l,long long r,long long x,long long y)
{
if(x<=l && r<=y)return sum[k];
long long mid=(l+r)>>1,ret=0;
if(x<=mid)ret=query(ls[k],l,mid,x,y);
if(mid<y)ret+=query(rs[k],mid+1,r,x,y);
return ret;
}
void dfs(long long x,long long fa)
{
d[x]=d[fa]+1,s[x]=1;
for(long long i=head[x];i;i=nxt[i])
{
long long y=to[i];
if(y==fa)continue;
dfs(y,x),s[x]+=s[y];
root[x]=merge(root[x],root[y]);
}
modify(root[x],1,n,d[x],s[x]-1);
}
int main()
{
// freopen("in.txt","r",stdin);
n=read(),m=read();
for(long long i=1;i<n;++i)
{
long long x=read(),y=read();
add(x,y),add(y,x);
}
dfs(1,0);
while(m--)
{
long long x=read(),k=read();
long long ans=min(k,d[x]-1)*(s[x]-1);
ans+=query(root[x],1,n,d[x]+1,min(n,d[x]+k));
printf("%lld\n",ans);
}
return 0;
}