链接:戳这里
D. Appleman and Tree
time limit per test2 seconds
memory limit per test256 megabytes
inputstandard input
outputstandard output
Appleman has a tree with n vertices. Some of the vertices (at least one) are colored black and other vertices are colored white.
Consider a set consisting of k (0 ≤ k < n) edges of Appleman's tree. If Appleman deletes these edges from the tree, then it will split into (k + 1) parts. Note, that each part will be a tree with colored vertices.
Now Appleman wonders, what is the number of sets splitting the tree in such a way that each resulting part will have exactly one black vertex? Find this number modulo 1000000007 (109 + 7).
Input
The first line contains an integer n (2 ≤ n ≤ 105) — the number of tree vertices.
The second line contains the description of the tree: n - 1 integers p0, p1, ..., pn - 2 (0 ≤ pi ≤ i). Where pi means that there is an edge connecting vertex (i + 1) of the tree and vertex pi. Consider tree vertices are numbered from 0 to n - 1.
The third line contains the description of the colors of the vertices: n integers x0, x1, ..., xn - 1 (xi is either 0 or 1). If xi is equal to 1, vertex i is colored black. Otherwise, vertex i is colored white.
Output
Output a single integer — the number of ways to split the tree modulo 1000000007 (109 + 7).
Examples
input
3
0 0
0 1 1
output
2
input
6
0 1 1 0 4
1 1 0 0 1 0
output
1
input
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
output
27
题意:
给出根节点为0的树,每个节点被涂成黑色或者白色
现在你需要切k下(1<=k<=n-1),问有多少种切法使得这k块联通块每块都只有一个节点被涂成黑色
思路:
设置dp[i][2]状态 dp[i][0]表示当前以i为根的子树没有子节点为黑色的种类
dp[i][1]表示当前以i为根的子树存在子节点有黑色的种类(这里指的是已经分好了块的种类)
每次对于一个叶子节点v 有:
dp[i][1]=dp[v][0]*dp[i][1]+dp[v][1]*dp[i][0] 儿子v节点中白色种类*子树当前黑色种类+儿子v黑色种类*子树当前白色种类
dp[i][0]=dp[v][0]*dp[i][0] 儿子节点v中的白色种类*子树当前白色种类
对于每一个子树的遍历完 有:
如果当前子树的根节点为黑色 dp[i][1]=dp[i][0] 满足一个联通块只有一个黑色节点的条件: 当前子树根节点为黑色,所 有黑色的儿子都不能与其相连,种类数为儿子的白色种类数
如果当前子树的根节点为白色 dp[i][0]=dp[i][0]+dp[i][1] 满足所有节点没有黑色的条件:当前的种类+砍掉所有的黑色儿子的种类
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<vector>
#include <ctime>
#include<queue>
#include<set>
#include<map>
#include<list>
#include<stack>
#include<iomanip>
#include<cmath>
#include<bitset>
#define mst(ss,b) memset((ss),(b),sizeof(ss))
///#pragma comment(linker, "/STACK:102400000,102400000")
typedef long long ll;
typedef long double ld;
#define mod 1000000007
#define Max 1e9
using namespace std;
struct edge{
int v,next;
}e[400100];
int a[200100],head[200100],tot=0;
int n;
void Add(int u,int v){
e[tot].v=v;
e[tot].next=head[u];
head[u]=tot++;
}
ll dp[200100][2];
void DFS(int u,int fa){
dp[u][0]=1;
dp[u][1]=0;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].v;
if(v==fa) continue;
DFS(v,u);
dp[u][1]=dp[v][0]*dp[u][1]%mod+dp[v][1]*dp[u][0]%mod;
dp[u][0]=dp[v][0]*dp[u][0]%mod;
dp[u][1]%=mod;
/*printf("\n");
printf("u=%d v=%d\n",u,v);
printf("%I64d %I64d\n",dp[u][0],dp[u][1]);*/
}
if(a[u]==0) dp[u][0]=(dp[u][1]+dp[u][0])%mod;
else dp[u][1]=dp[u][0];
}
int main(){
mst(dp,0);
mst(head,-1);
scanf("%d",&n);
for(int i=1;i<n;i++){
int x;
scanf("%d",&x);
Add(x,i);
Add(i,x);
}
for(int i=0;i<n;i++) scanf("%d",&a[i]);
DFS(0,-1);
printf("%I64d\n",dp[0][1]);
return 0;
}