时空限制 1000ms / 128MB
题目描述
这个游戏会给出你一棵树,这棵树有N个节点,根结点是R,系统会选中M个点 P 1 , P 2 . . . P M P_1,P_2...P_M P1,P2...PM,要Imakf回答有多少组点对 ( u i , v i ) (u_i,v_i) (ui,vi)的最近公共祖先是 P i P_i Pi 。Imakf是个小蒟蒻,他就算学了LCA也做不出,于是只好求助您了。
Imakf毕竟学过一点OI,所以他允许您把答案模$ (10^9+7)$
输入格式:
第一行 N , R , M
此后N-1行 每行两个数a,b 表示a,b之间有一条边
此后1行 M个数 表示
P
i
P_i
Pi
输出格式:
M行,每行一个数,第ii行的数表示有多少组点对 ( u i , v i ) (u_i,v_i) (ui,vi)的最近公共祖先是 P i P_i Pi
说明
N≤10000,M≤50000
题目分析
首先
s
i
z
e
[
u
]
size[u]
size[u]表示以
u
u
u为根的子树的结点个数(包括u)
设
{
v
i
}
\{v_i\}
{vi}为
p
i
p_i
pi的子节点集合
那么对于一个询问
p
i
p_i
pi的答案为
a
n
s
=
s
i
z
e
[
p
i
]
∗
2
−
1
+
∑
i
!
=
j
s
i
z
e
[
v
i
]
∗
s
i
z
e
[
v
j
]
ans=size[p_i]*2-1+\sum_{i!=j} size[v_i]*size[v_j]
ans=size[pi]∗2−1+∑i!=jsize[vi]∗size[vj]
前面
s
i
z
e
[
p
i
]
∗
2
+
1
size[p_i]*2+1
size[pi]∗2+1表示
p
i
p_i
pi的子树内每个点与
p
i
p_i
pi的LCA都是
p
i
p_i
pi
由于题目中点对是有序的,所以要乘2,再减去
(
p
i
,
p
i
)
(p_i,p_i)
(pi,pi)这一个重复的
前面
∑
i
!
=
j
s
i
z
e
[
v
i
]
∗
s
i
z
e
[
v
j
]
\sum_{i!=j} size[v_i]*size[v_j]
∑i!=jsize[vi]∗size[vj]运用乘法原理,不难理解
关键在于求这个值如果
O
(
n
2
)
O(n^2)
O(n2)枚举显然爆炸
假设
s
u
m
=
∑
s
i
z
e
[
v
i
]
sum=\sum size[v_i]
sum=∑size[vi],也就是所有
s
i
z
e
[
v
i
]
size[v_i]
size[vi]的和
那么
∑
i
!
=
j
s
i
z
e
[
v
i
]
∗
s
i
z
e
[
v
j
]
=
s
u
m
2
−
∑
s
i
z
e
[
v
i
]
2
\sum_{i!=j} size[v_i]*size[v_j]=sum^2-\sum size[v_i]^2
∑i!=jsize[vi]∗size[vj]=sum2−∑size[vi]2
这样询问就是
O
(
1
)
O(1)
O(1)回答了,只要在深搜时预处理即可
#include<iostream>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#include<cstdio>
using namespace std;
typedef long long lt;
#define sqr(x) ((x)*(x))
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
const lt mod=1e9+7;
const int maxN=20010;
const int maxM=100010;
int n,m,rt;
struct node{int v,nxt;}E[maxM];
int head[maxN],tot;
lt size[maxN],sum[maxN];
void add(int u,int v)
{
E[++tot].nxt=head[u];
E[tot].v=v;
head[u]=tot;
}
void dfs(int u,int pa)
{
size[u]=1; int k=0;
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==pa) continue;
dfs(v,u);
size[u]+=size[v]; size[u]%=mod;
}
lt tt=0;
for(int i=head[u];i;i=E[i].nxt)
{
if(E[i].v==pa) continue;
sum[u]+=size[E[i].v]; sum[u]%=mod;
tt+=sqr(size[E[i].v]); tt%=mod;
}
sum[u]=sqr(sum[u])%mod;
sum[u]-=tt; sum[u]=(sum[u]%mod+mod)%mod;
}
int main()
{
n=read();rt=read();m=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v); add(v,u);
}
dfs(rt,0);
while(m--)
{
int u=read();lt ans=(sum[u]+size[u]*2%mod-1)%mod;
printf("%lld\n",ans);
}
return 0;
}