Evensgn 剪树枝
时间限制:1s 空间限制:128MB
题目描述
繁华中学有一棵苹果树。苹果树有n 个节点(也就是苹果),n − 1 条边(也就
是树枝)。调皮的Evensgn 爬到苹果树上。他发现这棵苹果树上的苹果有两种:一
种是黑苹果,一种是红苹果。Evensgn想要剪掉 k 条树枝,将整棵树分成k + 1 个
部分。他想要保证每个部分里面有且仅有一个黑苹果。请问他一共有多少种剪树枝
的方案?
输入格式
第一行一个数字n,表示苹果树的节点(苹果)个数。
第二行一共n − 1 个数字p0, p1, p2, p3, ..., pn−2,pi表示第 i + 1 个节点和pi 节
点之间有一条边。注意,点的编号是0 到 n − 1。
第三行一共n 个数字 x0, x1, x2, x3, ..., xn−1。如果xi 是 1,表示i 号节点是黑
苹果;如果xi 是 0,表示i 号节点是红苹果。
输出格式
输出一个数字,表示总方案数。答案对109 + 7 取模。
样例输入1
3
0 0
0 1 1
6
样例输出1
2
样例输入2
6
0 1 1 0 4
1 1 0 0 1 0
样例输出2
1
样例输入3
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
样例输出3
27
数据范围
对于30% 的数据,1 ≤n ≤ 10。
对于60% 的数据,1 ≤n ≤ 100。
对于80% 的数据,1 ≤n ≤ 1000。
对于100% 的数据,1 ≤n ≤ 105。
对于所有数据点,都有0 ≤ pi ≤n − 1,xi = 0 或xi = 1。
特别地,60%中、80% 中、100%中各有一个点,树的形态是一条链。
题解:
其实是一个树规。
设f[i][j],f表示方案数,i表示以i为根节点的子树,j为0或1,0表示这棵子树的黑苹果数量等于砍的刀数,1代表砍的刀数比黑苹果数量少1.
为什么设这两种关系?我们可以想一下,整棵树有k个黑苹果,需要砍k-1刀,分成k个部分。如果把其中的一部分单独提出来,发现黑苹果数量比这一段砍的刀数少1,那么其他部分肯定是砍的刀数等于黑苹果数量。
接下来就是状态转移了。我们可以先跑一遍dfs,找出以i节点为根的子树中共有多少个黑苹果。如果为零,那这一段就不用搜了。因为这一棵树中反正也不能砍,对结果没有影响。
首先我们每向上走一层,把少砍一刀的情况加入砍全的情况,f[v][0]+=f[v][1](v为i的合法儿子)。接下来分两种情况,如果i节点为红,f[i][0]=∏f[v][0](v为i的合法儿子)。设sum=∏f[v][0],f[i][1]=Σ(sum/f[v][0]*f[v][1])。如果节点为黑,我们只考虑i的儿子,所以只存在f[i][1]=∏f[v][0].
最后输出f[1][1]。
注意时刻取模,sum/f[v][0]是用逆元。
附上代码
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;
struct tree{
int u,v,next;
}l[301000];
long long f[101000][5],mod=1000000007;
int lian[101000],e=0,n,fa[101000],size[101000],a[101000];
void bian(int,int);
void dfs(int);
void dp(int);
long long ksm(long long,long long);
int main()
{
scanf("%d",&n);
for(int i=2;i<=n;i++)
{
int x;
scanf("%d",&x);
x+=1;
bian(x,i);
bian(i,x);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
}
dfs(1);
dp(1);
printf("%lld",f[1][1]);
return 0;
}
void bian(int x,int y)
{
e++;
l[e].u=x;
l[e].v=y;
l[e].next=lian[x];
lian[x]=e;
}
void dfs(int x)
{
if(a[x]!=0)
size[x]+=1;
for(int i=lian[x];i;i=l[i].next)
{
int v=l[i].v;
if(v!=fa[x])
{
fa[v]=x;
dfs(v);
size[x]+=size[v];
}
}
}
void dp(int x)
{
int num=0;
vector<int> ve;
if(a[x]==0)
{
long long sum=1;
for(int i=lian[x];i;i=l[i].next)
{
int v=l[i].v;
if(v==fa[x])
continue;
if(size[v]==0)
continue;
dp(v);
num++;
ve.push_back(v);
f[v][0]+=f[v][1];
sum*=f[v][0];
sum%=mod;
}
f[x][0]=sum;
for(int i=0;i<num;i++)
{
long long k=sum*ksm(f[ve[i]][0],mod-2)%mod;
k*=f[ve[i]][1];
k%=mod;
f[x][1]+=k;
f[x][1]%=mod;
}
}
else
{
long long sum=1;
for(int i=lian[x];i;i=l[i].next)
{
int v=l[i].v;
if(v==fa[x])
continue;
if(size[v]==0)
continue;
dp(v);
num++;
ve.push_back(v);
f[v][0]+=f[v][1];
sum*=f[v][0];
sum%=mod;
}
f[x][1]=sum;
}
}
long long ksm(long long x,long long y)
{
long long ans=1,z=x;
while(y)
{
if((y&1)==1)
{
ans*=z;
ans%=mod;
}
y=y>>1;
z*=z;
z%=mod;
}
return ans;
}