Paint Box
题意:
给出n个格子和m种颜色,从m种颜色中选出k种颜色,对n个格子进行涂色,其中k种颜色种的每一种颜色都要用上,而且相邻的格子不能涂相同的颜色。 1 ≤ n , m ≤ 1 e 9 , 1 ≤ k ≤ 1 e 6 , k ≤ n , m 1 \leq n,m \leq 1e9,1 \leq k \leq 1e6,k \leq n,m 1≤n,m≤1e9,1≤k≤1e6,k≤n,m
做法:
- 这道题应该是容斥原理
- 容斥原理这个东西很难的,我也不是很懂,只能靠猜。
- 首先我们想象一个总的集合,那就是有k种颜色然后相邻的格子不能涂相同的颜色,的方法集合 S S S: ∣ S ∣ = k × ( k − 1 ) n − 1 |S|=k \times (k-1)^{n-1} ∣S∣=k×(k−1)n−1
- 这个答案是显而易见,然后考虑怎么容斥掉不合理的部分,上面的集合 S S S没有要求,每一种颜色都用上,那么我们首先减去只用了2,3,…,k-1种颜色的方案数: A k − 1 = ( k − 1 ) × ( k − 2 ) n − 1 A_{k-1}=(k-1)\times(k-2)^{n-1} Ak−1=(k−1)×(k−2)n−1
- 上面的式子表示用了2,3,…,k-1种颜色的方案数,我们还需要乘上一个组合数,从k中选择k-1种颜色。
- 那么一般来说这样就会有交错的地方,看看有没有呢?
- 很显然一个 A A A集合中很显然会包括其他另外一个 A A A集合的子集的染色,那么很简单我们就容斥出来了,只要在容斥上 A k − 2 = ( k − 2 ) × ( k − 3 ) n − 1 A_{k-2}=(k-2)\times(k-3)^{n-1} Ak−2=(k−2)×(k−3)n−1
- 最终答案就是,只对于固定的k种颜色: a n s = ∣ S ∣ − C k k − 1 A k − 1 + C k k − 2 A k − 2 − . . . . ans=|S|-C_{k}^{k-1}A_{k-1}+C_{k}^{k-2}A_{k-2}-.... ans=∣S∣−Ckk−1Ak−1+Ckk−2Ak−2−....
- 就这样就猜出来了,注意最后还要从m中选取k个。
代码:
#include <bits/stdc++.h>
using namespace std;
#define SZ(x) ((int)((x).size()))
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define pii pair<int,int>
#define pll pair<long long,long long>
#define rep(i, a, b) for(int i=(a);i<=(b);++i)
#define per(i, a, b) for(int i=(a);i>=(b);--i)
#define pb push_back
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1e6 + 10;
const ll mod = 1e9 + 7;
int ksm(int a, ll n) {
int ans = 1;
while (n) {
if (n & 1) ans = 1ll * ans * a % mod;
n >>= 1;
a = 1ll * a * a % mod;
}
return ans;
}
inline int add(int u, int v) { return (u += v) >= mod ? u - mod : u; }
inline int sub(int u, int v) { return (u -= v) < 0 ? u + mod : u; }
inline int mul(int u, int v) { return 1ll * u * v % mod; }
int inv[maxn], fac[maxn];
int Com(int n, int m) {
int ret = 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod;
return ret;
}
int n, m, k;
int main() {
fac[0] = inv[0] = 1;
rep(i, 1, maxn - 1) fac[i] = mul(fac[i - 1], i);
inv[maxn - 1] = ksm(fac[maxn - 1], mod - 2);
per(i, maxn - 2, 1) inv[i] = mul(inv[i + 1], i + 1);
int T;
cin >> T;
while (T--) {
cin >> n >> m >> k;
if (k == 1) {
if (n == 1) cout << m << endl;
else cout << "0\n";
continue;
}
int ans = mul(k, ksm(k - 1, n - 1)), f = 1;
per(i, k - 1, 0) {
int tmp = 1ll * Com(k, i) * i % mod * ksm(i - 1, n - 1) % mod;
if (f == 1) ans = sub(ans, tmp);
else ans = add(ans, tmp);
f ^= 1;
}
int cmk = 1;
for (int i = m, j = 1; j <= k; j++, i--) {
cmk = 1ll * cmk * i % mod;
}
ans = 1ll * ans * cmk % mod * inv[k] % mod;
cout << ans << endl;
}
return 0;
}