题意
给定一个 n n n 个顶点的树,每个顶点有黑色和白色两种颜色,删掉 k k k 条边,将这棵树分成 k + 1 k+1 k+1 部分,要求每个部分有且仅有一个黑色的点,求分割方案数(对 1 0 9 + 7 10^9+7 109+7 取模)
题解
对于一个结点 u u u 而言,如果它是黑点,那么将它子树中所有包含黑点的子树都删掉。如果不是,还是要删掉黑点,但保留一个包含黑点的子树。
定义
f
[
u
]
[
1
]
f[u][1]
f[u][1] 为子树只有一个黑点的方案数
f [ u ] [ 0 ] f[u][0] f[u][0] 为子树没有黑点的方案数
v v v 为 u u u 的子节点
f
[
u
]
[
1
]
=
f
[
u
]
[
1
]
×
(
f
[
v
]
[
0
]
+
f
[
v
]
[
1
]
)
+
f
[
u
]
[
0
]
×
f
[
v
]
[
1
]
f[u][1]=f[u][1] \times(f[v][0]+f[v][1])+f[u][0]\times f[v][1]
f[u][1]=f[u][1]×(f[v][0]+f[v][1])+f[u][0]×f[v][1]
f
[
u
]
[
0
]
=
f
[
u
]
[
0
]
×
(
f
[
v
]
[
1
]
+
f
[
v
]
[
0
]
)
f[u][0]=f[u][0]\times (f[v][1]+f[v][0])
f[u][0]=f[u][0]×(f[v][1]+f[v][0])
DFS(u):
DP[u][0] = 1
DP[u][1] = 0
foreach v : the children of vertex u
DFS(v)
DP[u][1] *= DP[v][0]
DP[u][1] += DP[u][0]*DP[v][1]
DP[u][0] *= DP[v][0]
if x[u] == 1:
DP[u][1] = DP[u][0]
else:
DP[u][0] += DP[u][1]
当以 u u u 为根的子树(指已经遍历过的子树)中有一个黑点时,对于结点 u u u 的子节点 v v v,若以 v v v 为根的子树中没有黑点,则将 v v v 与 u u u 连接,若有黑点,则不连接,这样得到的方案数就是 f [ v ] [ 0 ] + f [ v ] [ 1 ] f[v][0]+f[v][1] f[v][0]+f[v][1],根据乘法原理得到 f [ u ] [ 1 ] × ( f [ v ] [ 0 ] + f [ v ] [ 1 ] ) f[u][1]\times (f[v][0]+f[v][1]) f[u][1]×(f[v][0]+f[v][1]);当以 u u u 为根的子树(同上)中没有黑点时,若以 v v v 为根的子树中有黑点,则 u u u 和 v v v 必须相连,得到 d p [ u ] [ 0 ] × d p [ v ] [ 1 ] dp[u][0]\times dp[v][1] dp[u][0]×dp[v][1];再根据加法原理相加。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int maxn = 1e6 + 7;
const int N = 1e5 + 7, M = N * 2;
const int inf = 0x3f3f3f;
const int mod = 1000000007;
ll e[M], ne[M], h[N], w[N], idx = 0;
ll black[N];
ll f[N][2];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u) {
for(int i = h[u]; i != -1; i = ne[i]) {
int v = e[i];
dfs(v);
f[u][1] = (( f[u][1] % mod * (f[v][0] % mod + f[v][1] % mod) % mod ) % mod + ( f[u][0] % mod * f[v][1] % mod) % mod) % mod;
f[u][0] = (f[u][0] % mod * (f[v][1] % mod + f[v][0] % mod) ) % mod;
//f[u][1] = f[u][1] * (f[v][0] + f[v][1]) + (f[u][0] * f[v][1]);
//f[u][0] = f[u][0] * (f[v][1] + f[v][0]) ;
}
}
int main() {
// ios::sync_with_stdio(false);
int n;
scanf("%d", &n);
memset(h, -1, sizeof(h));
for(int i = 1; i < n; i++) {
int p;
scanf("%d", &p);
add(p, i);
}
for(int i = 0; i < n; i++) {
int c;
scanf("%d", &c);
f[i][c] = 1;
}
dfs(0);
printf("%d", f[0][1]);
return 0;
}
/*
3
0 0
0 1 1
*/
/*
数组开够了吗 开到上界的n+1次方
初始化了吗
*/