题意:将一个树切断一些边,使得每个子树都只有一个黑色的顶点,问有多少种切法。
思路:dp[u][0]代表这个顶点连接下面使得它所在的子树没有黑色顶点的情况,dp[u][1]代表这个顶点连接下面使得它所在的子树只有一个黑色顶点的情况。
AC代码如下:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;
typedef long long ll;
void exgcd(ll a,ll b,ll& d,ll& x,ll& y)
{
if(!b)d=a,x=1LL,y=0LL;
else exgcd(b,a%b,d,y,x),y-=x*(a/b);
}
ll inv(ll a,ll m)
{
ll d,x,y;
exgcd(a,m,d,x,y);
return d==1LL?(x+m)%m:-1LL;
}
vector<int> vc[100010];
ll dp[100010][2],MOD=1000000007;
int root,vis[100010],col[100010];
void dfs(int u)
{ int i,j,k,len=vc[u].size();
if(col[u]==1)
{ dp[u][1]=1;
for(i=0;i<len;i++)
{ dfs(vc[u][i]);
if(col[vc[u][i]]==1)
dp[u][1]*=dp[vc[u][i]][1];
else
dp[u][1]*=(dp[vc[u][i]][0]+dp[vc[u][i]][1]);
dp[u][1]%=MOD;
}
}
else if(col[u]==0)
{ dp[u][0]=1;
for(i=0;i<len;i++)
{ dfs(vc[u][i]);
if(col[vc[u][i]]==1)
dp[u][0]*=dp[vc[u][i]][1];
else
dp[u][0]*=(dp[vc[u][i]][0]+dp[vc[u][i]][1]);
dp[u][0]%=MOD;
}
for(i=0;i<len;i++)
{ if(col[vc[u][i]]==1)
dp[u][1]+=dp[u][0]*inv(dp[vc[u][i]][1],MOD)%MOD*dp[vc[u][i]][1];
else
dp[u][1]+=dp[u][0]*inv(dp[vc[u][i]][0]+dp[vc[u][i]][1],MOD)%MOD*dp[vc[u][i]][1];
dp[u][1]%=MOD;
}
}
dp[u][1]%=MOD;
dp[u][0]%=MOD;
}
void solve()
{ int T,t,n,m,i,j,k=0,u;
scanf("%d",&n);
for(i=1;i<n;i++)
{ scanf("%d",&u);
vc[u].push_back(i);
}
for(i=0;i<n;i++)
{ scanf("%d",&col[i]);
k+=col[i];
}
if(k==0)
{ printf("0\n");
return;
}
if(k==1)
{ printf("1\n");
return;
}
dfs(0);
printf("%I64d\n",dp[0][1]);
}
int main()
{ solve();
}