Description
设T 为一棵有根树,我们做如下的定义:
• 设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道高明到哪里去了”。
• 设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数x,那么称“a 与b 谈笑风生”。
给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:
1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
2. a和b 都比 c不知道高明到哪里去了;
3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。
Input
输入文件的第一行含有两个正整数n和q,分别代表有根树的点数与询问的个数。接下来n - 1行,每行描述一条树上的边。每行含有两个整数u和v,代表在节点u和v之间有一条边。
接下来q行,每行描述一个操作。第i行含有两个整数,分别表示第i个询问的p和k。
Output
输出 q 行,每行对应一个询问,代表询问的答案。
Sample Input
5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
Sample Output
3
1
3
HINT
1<=P<=N
1<=K<=N
N<=300000
Q<=300000
题解
主席树的新用法get√
分析题意:
给定一个点,求与其树上距离不超过k,且与该点互为父子关系的点的个数乘以以这两点为祖先的点的个数,即
对于左边,我们可以在求出原树的dfs序后直接计算得出;对于右边,我们可以按照dfs序向主席树中加入每个点的size值,然后通过在主席树上计算得出。
主席树的功能:记录区间内某一深度的所有点的size之和。
CODE:
#include<cstdio>
typedef long long ll;
const int N=3e5+10;
struct edge
{
int nxt,to;
}a[N<<1];
struct tree
{
int l,r;
ll num;
}t[N*20];
int head[N],deep[N],root[N];
int s[N],b[N],e[N];
ll size[N*20];
int n,m,x,y,num,tot,cnt;
ll ans;
inline int min(const int &a,const int &b){return a<b?a:b;}
inline void read(int &n)
{
n=0;char c=getchar();
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') n=n*10+c-48,c=getchar();
}
inline void add(int x,int y)
{
a[++num].nxt=head[x],a[num].to=y,head[x]=num;
a[++num].nxt=head[y],a[num].to=x,head[y]=num;
}
void dfs(int now,int fa,int depth)
{
deep[now]=depth;
size[now]=1;
b[now]=++tot;
s[tot]=now;
for(int i=head[now];i;i=a[i].nxt)
if(a[i].to!=fa)
{
dfs(a[i].to,now,depth+1);
size[now]+=size[a[i].to];
}
e[now]=tot;
}
void Add(int l,int r,int pos,int &now,int pre,int num)
{
now=++cnt;
t[now].num=t[pre].num+num;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) Add(l,mid,pos,t[now].l,t[pre].l,num),t[now].r=t[pre].r;
else Add(mid+1,r,pos,t[now].r,t[pre].r,num),t[now].l=t[pre].l;
}
ll ask(int L,int R,int l,int r,int now,int pre)
{
if(L<=l&&r<=R) return t[now].num-t[pre].num;
int mid=(l+r)>>1;
ll ans=0;
if(L<=mid) ans+=ask(L,R,l,mid,t[now].l,t[pre].l);
if(R>mid) ans+=ask(L,R,mid+1,r,t[now].r,t[pre].r);
return ans;
}
int main()
{
read(n),read(m);
for(int i=1;i<n;i++)
read(x),read(y),add(x,y);
dfs(1,0,1);
for(int i=1;i<=n;i++)
Add(1,n,deep[s[i]],root[i],root[i-1],size[s[i]]-1);
while(m--)
{
read(x),read(y);
ans=(ll)min(deep[x]-1,y)*(size[x]-1);
ans+=ask(deep[x]+1,min(deep[x]+y,n),1,n,root[e[x]],root[b[x]-1]);
printf("%lld\n",ans);
}
return 0;
}