简单介绍二项式反演:
如果存在
a
n
=
∑
i
=
s
n
(
n
i
)
b
i
a_n = \sum_{i = s}^n\left(\begin{array}{l} n \\ i \end{array}\right)b_i
an=∑i=sn(ni)bi则可以得到
b
n
=
∑
i
=
s
n
(
−
1
)
n
−
i
(
n
i
)
a
i
b_n = \sum_{i = s} ^ n(-1)^{n - i}\left(\begin{array}{l} n \\ i \end{array}\right)a_i
bn=∑i=sn(−1)n−i(ni)ai
题意:给出一个含有n个节点的树,以及k个颜色,询问有多少种方式正好用k个颜色给树染色,并且任意两个相邻的节点颜色不同。
题解:我们发现只要一个节点与他的父节点颜色不同即可,所以对于根节点有 k ∗ ( k − 1 ) n − 1 k*(k - 1)^{n - 1} k∗(k−1)n−1种方案,我们假设 f ( n ) f(n) f(n)为恰好用n种颜色染色整棵树的方案数, g ( n ) g(n) g(n)为至多用n种颜色染色整棵树的方案数,显然 g ( n ) = ∑ i = 2 n ( n i ) f ( i ) g(n) = \sum_{i = 2}^{n}\left(\begin{array}{l} n \\ i \end{array}\right)f(i) g(n)=∑i=2n(ni)f(i)由二项式反演可以得到 f ( n ) = ∑ i = 2 n ( − 1 ) n − i ( n i ) g ( i ) f(n) = \sum_{i = 2} ^ n(-1)^{n - i}\left(\begin{array}{l} n \\ i \end{array}\right)g(i) f(n)=∑i=2n(−1)n−i(ni)g(i),我们可以轻松的知道由k种颜色颜色整个树的方案数为 k ∗ ( k − 1 ) n − 1 k * (k - 1)^{n - 1} k∗(k−1)n−1,因为只有根节点可以选择k种颜色,其余节点都要保持与他的父节点保持不同,然后就可以再 O ( n ∗ k ) O(n*k) O(n∗k)的复杂度内解决此问题。
a c c o d e : ac code: accode:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define met(a, b) memset(a, b, sizeof(a))
#define rep(i, a, b) for(int i = a; i <= b; i++)
#define per(i, a, b) for(int i = a; i >= b; i--)
#define fi first
#define se second
#define pb push_back
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
ll qp(ll base, ll n) {ll res = 1; while(n) {if(n & 1) res = (res * base) % mod; base = (base * base) % mod; n >>= 1;} return res;};
ll inv(ll x) {return qp(x, mod - 2);}
ll fac(ll n) {ll res = 1; for(int i = 1; i <= n; i++) res = (res * i) % mod; return res;}
ll C(ll n, ll m){return fac(n) * inv(fac(m)) % mod * inv(fac(n - m)) % mod;}
vector<int> G[maxn];
ll dp[maxn];
ll dfs(int u, int col) {
dp[u] = 1;
for(auto v : G[u]) {
dfs(v, col);
dp[u] = dp[u] * dp[v] % mod * max(col - 1, 0) % mod;
}
return dp[u] * col % mod;
}
int main() {
int n, k, p;
while(~scanf("%d%d", &n, &k)) {
rep(i, 1, maxn - 1) G[i].clear();
rep(i, 1, n - 1) {
scanf("%d", &p);
G[p].pb(i);
}
ll sum = 0, f = 1;
rep(i, 0, k) {
f = (i % 2) ? -1 : 1;
sum = (sum + f * C(k, i) % mod * dfs(0, k - i) % mod + mod) % mod;
}
printf("%lld\n", sum);
}
return 0;
}