题目描述:![在这里插入图片描述](https://img-blog.csdnimg.cn/20200525171637516.png)
3
≤
n
≤
1
0
5
3\le n\le10^5
3≤n≤105
题目分析:
最后
v
a
l
n
−
1
,
0
val_{n-1,0}
valn−1,0和
v
a
l
n
−
1
,
1
val_{n-1,1}
valn−1,1的递推式只和
≤
n
−
1
\le n-1
≤n−1的项有关(
v
a
l
n
,
∗
=
0
val_{n,*}=0
valn,∗=0),而根据
v
a
l
n
−
2
val_{n-2}
valn−2的递推式已经求出了
v
a
l
n
−
1
,
∗
=
A
∗
v
a
l
1
,
0
+
B
∗
v
a
l
1
,
1
+
C
val_{n-1,*}=A*val_{1,0}+B*val_{1,1}+C
valn−1,∗=A∗val1,0+B∗val1,1+C,所以可以列出两个等式求解。
注意这里选择的点要算次数的话必须满足选择后局面没有结束(也就是说它还可以走),所以 v a l 1 , 1 val_{1,1} val1,1和 v a l n − 1 , 0 val_{n-1,0} valn−1,0的递推式中不加 1 n \frac 1n n1。
最后的答案还要加上选择一个起点的并走一步的期望移动距离,即
1
n
2
∗
∑
d
i
s
[
i
]
\frac 1{n^2}*\sum dis[i]
n21∗∑dis[i]
d
i
s
[
i
]
dis[i]
dis[i]是点
i
i
i到其他所有点的距离和,可以两遍
d
f
s
dfs
dfs求出。
Code(代码中的 f i , ∗ f_{i,*} fi,∗表示的是 i i i个0的时候的情况,与题解意义相反):
#include<bits/stdc++.h>
#define maxn 100005
using namespace std;
const int mod = 1e9+7;
int n,m,b[2],siz[maxn],dis[maxn],F[2],G[2],ans,sum,inv[maxn];
char a[maxn];
int fir[maxn],nxt[maxn<<1],to[maxn<<1],tot;
void line(int x,int y){nxt[++tot]=fir[x],fir[x]=tot,to[tot]=y;}
void dfs1(int u,int ff){
siz[u]=1;
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff)
dfs1(v,u),dis[u]=(dis[u]+dis[v]+siz[v])%mod,siz[u]+=siz[v];
}
void dfs2(int u,int ff){
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff)
dis[v]=(dis[u]-siz[v]+n-siz[v])%mod,dfs2(v,u);
}
int Pow(int a,int b){int s=1;for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) s=1ll*s*a%mod;return s;}
struct node{
int A,B,C; node(int A=0,int B=0,int C=0):A(A),B(B),C(C){}
node operator - (node p){return node((A-p.A)%mod,(B-p.B)%mod,(C-p.C)%mod);}
node operator + (node p){return node((A+p.A)%mod,(B+p.B)%mod,(C+p.C)%mod);}
node operator * (int t){return node(1ll*A*t%mod,1ll*B*t%mod,1ll*C*t%mod);}
int calc(){return (1ll*A*F[0]+1ll*B*F[1]+C)%mod;}
}f[maxn][2];
int main()
{
scanf("%d%s",&n,a+1);
for(int i=1;i<=n;i++) b[a[i]-'0']++;
for(int i=2,x;i<=n;i++) scanf("%d",&x),line(x,i),line(i,x);
dfs1(1,0),dfs2(1,0);
f[1][0]=node(1,0,0),f[1][1]=node(0,1,0);
inv[0]=inv[1]=1;
for(int i=2;i<=n;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for(int i=1;i<n-1;i++){
f[i+1][0]=(f[i][0]*n-f[i-1][0]*(i-1)-(f[i-1][1]+node(0,0,i!=1)))*inv[n-i];
f[i+1][1]=(f[i][1]*n-f[i-1][1]*i-(f[i+1][0]+node(0,0,1)))*inv[n-i-1];
}
node P = (f[n-2][0]*(n-2)+f[n-2][1]+node(0,0,1))*inv[n]-f[n-1][0];
//cout<<P.A<<' '<<P.B<<' '<<P.C<<endl;
node Q = f[n-2][1]*(1ll*(n-1)*inv[n]%mod)-f[n-1][1];
//cout<<Q.A<<' '<<Q.B<<' '<<Q.C<<endl;
if(!P.B) swap(P,Q);
P = P - Q*(1ll*P.A*Pow(Q.A,mod-2)%mod);
F[1]=1ll*-P.C*Pow(P.B,mod-2)%mod;
F[0]=1ll*-(Q.C+1ll*Q.B*F[1])%mod*Pow(Q.A,mod-2)%mod;
G[0]=f[b[0]][0].calc(),G[1]=f[b[0]][1].calc();
for(int i=1;i<=n;i++) ans=(ans+1ll*G[a[i]-'0']*dis[i])%mod, sum=(sum+dis[i])%mod;
printf("%d\n",((1ll*sum*inv[n]+ans)%mod*inv[n]%mod+mod)%mod);
}