题意:
一颗
n
n
n 个节点、
n
−
1
n-1
n−1 条无向边的树,可以将树上的边任意染成黑色或者是白色。
给出
m
m
m 对点
u
、
v
u、v
u、v,对于每一对点,都要求从
u
u
u 到
v
v
v 的通路上,至少有一条黑边。
求符合
m
m
m 对点的树的个数。
思路:
首先,所有的染色方案为 a = 2 n − 1 a=2^{n-1} a=2n−1,但是所有合法的方案很难求,所以我们换个角度,求所有不合法的方案 b b b,最终答案就是 a − b a-b a−b。
假设不符合第
i
i
i 对点的情况为
S
i
S_i
Si ,那么
b
=
⋃
i
=
1
m
S
i
b=\bigcup\limits_{i=1}^{m} S_i
b=i=1⋃mSi。这时候思路是不是突然就很敞亮敞亮,该用容斥了。设
S
i
S_i
Si 中的方案数量为
∣
S
i
∣
\left| S_i \right|
∣Si∣,则
b
=
⋃
i
=
1
m
S
i
=
∑
i
=
1
m
∣
S
i
∣
−
∑
i
,
j
;
1
≤
i
<
j
≤
m
∣
S
i
∩
S
j
∣
+
…
…
+
(
−
1
)
m
−
1
∗
∣
S
i
∩
…
…
∩
S
m
∣
b=\bigcup\limits_{i=1}^{m} S_i=\sum\limits_{i=1}^m{\left| S_i \right|}-\sum\limits_{i,j;1\leq i<j\leq m}{\left| S_i\cap S_j \right|}+……+(-1)^{m-1}*\left| S_i\cap……\cap S_m \right|
b=i=1⋃mSi=i=1∑m∣Si∣−i,j;1≤i<j≤m∑∣Si∩Sj∣+……+(−1)m−1∗∣Si∩……∩Sm∣
现在只剩下一个问题没有解决了,那就是不合法方案
S
i
S_i
Si 的计算。
这时候注意一下
n
n
n 的范围,我们完全可以用二进制状态压缩的方法去存从根节点到任意节点的通路所经过的边
d
u
d_u
du,那么
d
u
⊕
d
v
d_u\oplus d_v
du⊕dv 就表示从节点
u
u
u 到节点
v
v
v 的通路所经过的边。这些边必然要染成白色,那么那些多余的边,就可以任意染色。如果
d
i
d_i
di 中有
c
n
t
cnt
cnt 条边的话,
∣
S
i
∣
=
2
n
−
1
−
c
n
t
\left| S_i \right|=2^{n-1-cnt}
∣Si∣=2n−1−cnt 。
单个数对的不合法条件数可以求出,那么对于多个不合法条件的并集,只需要对通路中的路径求并后,之后的处理类似。
时间复杂度: O ( 2 m ) O(2^{m}) O(2m)
#include <bits/stdc++.h>
using namespace std;
int n, m;
long long s[55], d[55];
int h[1010], e[1010], ne[1010], idx;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u, int fa, long long cnt) {
s[u] = cnt;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa)
continue;
dfs(j, u, cnt | (1ll << j));
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++)
h[i] = -1;
for (int i = 1, a, b; i < n; i++) {
cin >> a >> b;
add(a, b), add(b, a);
}
dfs(1, -1, 0ll);
cin >> m;
for (int i = 0, a, b; i < m; i++) {
cin >> a >> b;
d[i] = s[a] ^ s[b];
}
long long res = 1ll << (n - 1);
for (int i = 1; i < (1ll << m); i++) {
long long num = 0ll;
int cnt = 0;
for (int j = 0; j < m; j++)
if (i & (1 << j)) {
cnt++;
num |= d[j];
}
int tmp = n - 1;
while (num) {
if (num & 1)
tmp--;
num >>= 1;
}
if (cnt & 1)
res -= (1ll << tmp);
else
res += (1ll << tmp);
}
cout << res;
return 0;
}