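// Note: keeping prefix and suffix products of the children's (f[v][0]+f[v][1]) is enough here; there is no need for modular inverses (which do not exist when such a factor is 0 mod p).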
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<utility>
#include<vector>
#define R register
#define ll long long
#define max_n 100010
#define mod 1000000007
using namespace std;
// Adjacency stored as linked arc lists: st[u] is the index of the first arc
// leaving u (-1 = none), edge[e].to is the arc's head and edge[e].nex the index
// of the next arc leaving the same tail vertex.
struct ED{int to,nex;};
// Each undirected edge is stored as two arcs, so 2*max_n arcs always cover the
// n-1 edges of a tree whose vertices fit in col[]/st[]. Deriving the size from
// max_n keeps the two limits from drifting apart.
ED edge[2*max_n];
int n,et;                 // vertex count; number of arcs stored so far
int col[max_n],st[max_n]; // colour of vertex u (1-indexed); head of u's arc list
ll f[max_n][2];           // DP values per vertex written by dfs, reduced modulo `mod`
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. The character that terminates the
// number is consumed. Returns 0 if input ends before any digit is seen.
int read()
{
    // getchar() returns an int so that EOF can be told apart from data; a
    // char variable cannot hold EOF reliably and made the skip loop spin forever.
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9'))ch=getchar();
    int xx=0;
    while(ch>='0'&&ch<='9')
    {
        xx=xx*10+(ch-'0');
        ch=getchar();
    }
    return xx;
}
// Extended Euclidean algorithm with coefficients reduced modulo the global `mod`.
// For non-negative a and b, on return a*x + b*y == gcd(a, b) (mod `mod`), with
// x and y both in [0, mod).
// Called as exgcd(x, y, k, mod) with 0 <= k it yields x == k^{-1} (mod `mod`)
// whenever k is not a multiple of `mod`; if k is a multiple of `mod` no inverse
// exists and x comes out as 0.
void exgcd(ll &x,ll &y,ll a,ll b)
{
    if(!b)
    {
        // gcd(a, 0) == a == a*1 + 0*0
        x=1;
        y=0;
        return;
    }
    exgcd(x,y,b,a%b);
    // Lift the coefficients found for (b, a % b) back to (a, b):
    //   b*x' + (a%b)*y' == a*y' + b*(x' - (a/b)*y')
    const ll quotient=a/b;
    const ll prev_x=x;
    x=(y+mod)%mod;
    y=(prev_x-quotient*y%mod+mod)%mod;
}
void dfs(int u,int pre)
{
R int e,v;
R ll x,y;
R ll sum1=0,sum2=1;
for(e=st[u];e!=-1;e=edge[e].nex)
if((v=edge[e].to)!=pre)
{
dfs(v,u);
sum2=sum2*(f[v][1]+f[v][0])%mod;
}
if(col[u]==1){f[u][1]=sum2,f[u][0]=0;return;}
for(e=st[u];e!=-1;e=edge[e].nex)
if((v=edge[e].to)!=pre)
{
exgcd(x,y,(f[v][0]+f[v][1])%mod,mod);
sum1=(sum1+f[v][1]*(sum2*x%mod)%mod)%mod;
}
f[u][1]=sum1%mod,f[u][0]=sum2%mod;
}
// Reads the tree from stdin and prints the DP answer for the whole tree.
// Input, in the order consumed below: n, then col[1..n], then n-1 pairs "u v"
// giving undirected edges. Vertex 1 is used as the root, and the printed value
// is f[1][1] (see dfs for its meaning).
int main()
{
    n=read();
    memset(st,-1,sizeof(st)); // -1 = empty arc list for every vertex
    et=0;
    for(int i=1;i<=n;++i)col[i]=read();
    for(int i=1;i<n;++i)
    {
        const int u=read();
        const int v=read();
        // Store the edge as two arcs (u->v and v->u), each pushed onto the
        // front of its tail's list. Plain member assignment replaces the
        // non-standard (ED){...} compound literal.
        edge[et].to=v;edge[et].nex=st[u];st[u]=et++;
        edge[et].to=u;edge[et].nex=st[v];st[v]=et++;
    }
    // 0 is never a vertex id, so it is an unambiguous "no parent" for the root.
    dfs(1,0);
    printf("%lld",f[1][1]);
    return 0;
}