题目描述
lre有一棵 n + 1 n+1 n+1 个节点的树, 0 0 0 号点是树根,第 i i i 个点的父亲是 a i a_i ai 。
每个点上都可以放一个弹珠或不放。之后每一回合,lre都会把所有弹珠移动到它们所在的节点的父亲节点。
若一个节点上有大于 1 1 1 个弹珠,它们会一起被lre打爆消失。原来在 0 0 0 号节点上的弹珠则会被lre收集起来。
lre觉得这可以出一道题,就问你所有放置弹珠的方案操作完(树上没有弹珠)之后可收集到的弹珠的数目之和 ( m o d 1 0 9 + 7 ) (mod~~10^9+7) (mod 109+7) 。
题解
考虑到同一深度的贡献为 0 / 1 0/1 0/1 ,然后将方案数转化成概率好像更好计算,所以考虑暴力 dp \text{dp} dp : f [ i ] [ j ] [ 0 / 1 ] f[i][j][0/1] f[i][j][0/1] 表示 i i i 子树内,距离 i i i 深度为 j j j 的贡献为 0 / 1 0/1 0/1 的概率,转移的话 f [ i ] [ 0 ] [ 0 / 1 ] = 1 2 , f [ i ] [ j ] [ 1 ] = ∑ v f [ v ] [ j − 1 ] [ 1 ] ∏ v ′ ≠ v f [ v ′ ] [ j − 1 ] [ 0 ] , f [ i ] [ j ] [ 0 ] = 1 − f [ i ] [ j ] [ 1 ] f[i][0][0/1]=\frac{1}{2},f[i][j][1]=\sum_{v} f[v][j-1][1] \prod_{v' \ne v} f[v'][j-1][0],f[i][j][0]=1-f[i][j][1] f[i][0][0/1]=21,f[i][j][1]=∑vf[v][j−1][1]∏v′=vf[v′][j−1][0],f[i][j][0]=1−f[i][j][1] 。然后这样是 O ( n 2 ) O(n^2) O(n2) 的。考虑优化,如果我们把深度最深的儿子的 dp \text{dp} dp 值直接继承到 i i i ,然后剩下的暴力合并上去的话就是 O ( n ) O(n) O(n) 的效率了。具体证明的话就是每个点对答案的贡献最多只有在某一次会被暴力合并。(by 苏ak:这和长链剖分是一样的)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=2e5+5,P=1e9+7;
int n,hd[N],V[N],nx[N],t,f[N],p[N],d[N],id[N];
struct O{int x,y;};vector<O>g[N];
int K(int x,int y){
int z=1;
for (;y;y>>=1,x=1ll*x*x%P)
if (y&1) z=1ll*z*x%P;
return z;
}
void add(int u,int v){
nx[++t]=hd[u];V[hd[u]=t]=v;
}
void upd(int x,int u){
for (int i=1;i<=x;i++)
p[i]=1ll*g[id[u]][d[u]-i].x*p[i]%P;
}
void cal(int x,int u){
for (int i=1;i<=x;i++)
(f[i]+=1ll*K(g[id[u]][d[u]-i].x,P-2)*p[i]%P*g[id[u]][d[u]-i].y%P)%=P;
}
void dfs(int u){
int son=n+1,x=0;
for (int v,i=hd[u];i;i=nx[i]){
dfs(v=V[i]);
if (d[v]>d[son]) son=v;
}
if (son>n) id[u]=++t,d[u]=1;
else{
id[u]=id[son];d[u]=d[son]+1;
for (int i=hd[u];i;i=nx[i])
if (V[i]!=son) x=max(d[V[i]],x);
for (int i=1;i<=x;i++) p[i]=1,f[i]=0;
for (int i=hd[u];i;i=nx[i]){
if (V[i]==son) upd(x,son);
else upd(d[V[i]],V[i]);
}
for (int i=hd[u];i;i=nx[i]){
if (V[i]==son) cal(x,son);
else cal(d[V[i]],V[i]);
}
for (int i=1;i<=x;i++)
g[id[u]][d[son]-i].y=f[i],
g[id[u]][d[son]-i].x=(P+1-f[i])%P;
}
g[id[u]].push_back((O){(P+1)>>1,(P+1)>>1});
}
int main(){
cin>>n;
for (int i=1,x;i<=n;i++)
scanf("%d",&x),add(x,i);
dfs(t=0);int x=0;
for (int i=0;i<d[0];i++)
(x+=g[id[0]][i].y)%=P;
cout<<1ll*x*K(2,n+1)%P<<endl;
return 0;
}