Solution
- 一道组合计数好题,orz zzq。
- 对原问题进行转化,答案 = ∑ i = 1 n f ( a n s = i ) × i = ∑ i = 1 n f ( a n s ≥ i ) = ∑ i = 1 n ( m n − f ( a n s < i ) ) = n × m n − ∑ i = 0 n − 1 f ( a n s ≤ i ) ( f 为 方 案 数 ) = \sum \limits_{i = 1}^{n} f(ans = i) \times i = \sum \limits_{i = 1}^{n} f(ans \ge i) = \sum \limits_{i = 1}^{n}(m ^ n - f(ans < i)) = n \times m^n - \sum \limits_{i = 0}^{n - 1}f(ans \le i)(f为方案数) =i=1∑nf(ans=i)×i=i=1∑nf(ans≥i)=i=1∑n(mn−f(ans<i))=n×mn−i=0∑n−1f(ans≤i)(f为方案数)。
- 考虑如何计算 ∑ i = 0 n − 1 f ( a n s ≤ i ) \sum \limits_{i = 0}^{n - 1}f(ans \le i) i=0∑n−1f(ans≤i)。
- 先把序列分成 j j j 个块,相邻的块颜色不同,染色的方案数为 m ( m − 1 ) j − 1 m(m - 1)^{j - 1} m(m−1)j−1。
- 计算分成 j j j 个块的方案要考虑 a n s ≤ i ans \le i ans≤i 的限制,直接计算并不容易,我们通过容斥,枚举 k k k 个长度 > i > i >i 的块,把它们的长度减去 i i i,剩下的部分就可以直接用隔板法计算。
- 因此 ∑ i = 0 n − 1 f ( a n s ≤ i ) = ∑ i = 0 n − 1 ∑ j = 1 n m ( m − 1 ) j − 1 ∑ k = 0 j ( − 1 ) k C j k C n − i k − 1 j − 1 \sum \limits_{i = 0}^{n - 1} f(ans \le i) = \sum \limits_{i = 0}^{n - 1} \sum \limits_{j = 1}^{n}m(m - 1)^{j - 1} \sum \limits_{k = 0}^{j} (-1)^kC_{j}^{k}C_{n - ik - 1}^{j - 1} i=0∑n−1f(ans≤i)=i=0∑n−1j=1∑nm(m−1)j−1k=0∑j(−1)kCjkCn−ik−1j−1。
- 考虑简化这个式子,把无关项外移,先枚举 k k k 再枚举 j j j,得到 = m ∑ i = 0 n − 1 ∑ k = 0 n ( − 1 ) k ∑ j = max { k , 1 } n ( m − 1 ) j − 1 C j k C n − i k − 1 j − 1 =m\sum \limits_{i = 0}^{n - 1}\sum\limits_{k = 0}^{n}(-1)^k\sum \limits_{j = \max\{k, 1\}}^{n}(m - 1)^{j - 1}C_{j}^{k}C_{n - ik - 1}^{j - 1} =mi=0∑n−1k=0∑n(−1)kj=max{k,1}∑n(m−1)j−1CjkCn−ik−1j−1
- 考虑组合数的限制 k − 1 ≤ j − 1 ≤ n − i k − 1 k - 1 \le j - 1 \le n - ik - 1 k−1≤j−1≤n−ik−1,所以 k ≤ n − i k , ( i + 1 ) k ≤ n k \le n - ik, (i + 1)k \le n k≤n−ik,(i+1)k≤n,枚举 i , k i, k i,k 的复杂度为调和级数。
- 现在只要能够快速计算枚举 j j j 的部分,问题就能解决了。
- 先做一些简单的变换: C n − i k − 1 j − 1 = ( n − i k − 1 ) ! ( j − 1 ) ! ( n − i k − j ) ! = j n − i k × ( n − i k ) ! j ! × ( n − i k − j ) ! = j n − i k C n − i k j C_{n - ik - 1}^{j - 1} = \frac{(n - ik - 1)!}{(j - 1)!(n - ik - j)!} = \frac{j}{n - ik} \times \frac{(n - ik)!}{j! \times (n - ik - j)!} = \frac{j}{n - ik}C_{n - ik}^{j} Cn−ik−1j−1=(j−1)!(n−ik−j)!(n−ik−1)!=n−ikj×j!×(n−ik−j)!(n−ik)!=n−ikjCn−ikj
- 代入原式,得到 = m ∑ i = 0 n − 1 ∑ k = 0 n ( − 1 ) k n − i k ∑ j = max { k , 1 } n j ( m − 1 ) j − 1 C j k C n − i k j =m\sum \limits_{i = 0}^{n - 1}\sum\limits_{k = 0}^{n}\frac{(-1)^k}{n - ik}\sum \limits_{j = \max\{k, 1\}}^{n}j(m - 1)^{j - 1}C_{j}^{k}C_{n - ik}^{j} =mi=0∑n−1k=0∑nn−ik(−1)kj=max{k,1}∑nj(m−1)j−1CjkCn−ikj
- 考虑寻找枚举 j j j 部分的组合意义:
- 在 n − i k n - ik n−ik 个点中选 j j j 个点( C n − i k j C_{n - ik}^{j} Cn−ikj);
- 在 j j j 个点中选 k k k 个点( C j k C_{j}^{k} Cjk);
- 在 j j j 个点中选一个关键点( j j j);
- j j j 个点中除关键点外的点用 m − 1 m - 1 m−1 种颜色染色( ( m − 1 ) j − 1 (m - 1)^{j - 1} (m−1)j−1)。
- 注意到若该组合意义下存在方案,要满足 j j j 至少为 1。
- 考虑先枚举选 k k k 个点:
- 在 n − i k n - ik n−ik 个点中选 k k k 个点 ( C n − i k k C_{n - ik}^{k} Cn−ikk);
- 分两种情况讨论:
- 关键点不在
k
k
k 个点之中:
[1] 将 k k k 个点用 m − 1 m - 1 m−1 种颜色染色( ( m − 1 ) k (m - 1)^{k} (m−1)k);
[2] 在 n − i k − k n - ik - k n−ik−k 个点中选一个关键点( n − i k − k n - ik - k n−ik−k);
[3] 在剩下的 n − i k − k − 1 n - ik - k - 1 n−ik−k−1 点中选一个子集用 m − 1 m - 1 m−1 种颜色染色,相当于将这 n − i k − 1 n - ik - 1 n−ik−1 个点用 m m m 种颜色染色( m n − i k − k − 1 m^{n - ik - k - 1} mn−ik−k−1); - 关键点在
k
k
k 个点之中:
[1] 在 k k k 个点中选一个关键点( k k k);
[2] 将其余 k − 1 k - 1 k−1 个点用 ( m − 1 ) (m - 1) (m−1) 种颜色染色( ( m − 1 ) k − 1 (m - 1)^{k - 1} (m−1)k−1);
[3] 在剩下的 n − i k − k n - ik - k n−ik−k 个点中选一个子集用 m − 1 m - 1 m−1 种颜色染色,计算与上面同理( m n − i k − k m^{n - ik - k} mn−ik−k)
- 关键点不在
k
k
k 个点之中:
- 最终整理得到答案的式子为: n × m n − m ∑ i = 0 n − 1 ∑ k = 0 n ( − 1 ) k n − i k C n − i k k [ ( n − i k − k ) ( m − 1 ) k m n − i k − k − 1 + k ( m − 1 ) k − 1 m n − i k − k ] n \times m^{n} - m\sum \limits_{i = 0}^{n - 1} \sum \limits_{k = 0}^{n} \frac{(-1)^k}{n - ik} C_{n - ik}^{k}[(n - ik - k)(m - 1)^k m^{n - ik - k -1} + k(m - 1)^{k - 1} m^{n - ik - k}] n×mn−mi=0∑n−1k=0∑nn−ik(−1)kCn−ikk[(n−ik−k)(m−1)kmn−ik−k−1+k(m−1)k−1mn−ik−k]
- 预处理 m , m − 1 m, m - 1 m,m−1 的次幂以及 1 1 1~ n n n 阶乘、 1 1 1~ n n n 逆元、 1 1 1~ n n n 阶乘逆元,即可在 O ( n log n ) O(n \log n) O(nlogn) 的复杂度内计算上述式子。
Code
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <cstdio>
#include <cmath>
#include <ctime>
template <class T>
inline void read(T &res)
{
char ch; bool flag = false; res = 0;
while (ch = getchar(), !isdigit(ch) && ch != '-');
ch == '-' ? flag = true : res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + ch - 48;
flag ? res = -res : 0;
}
const int N = 3e5 + 5;
int fra[N], inv_fra[N], inv[N], ex_m[N], ex_m1[N];
int n, m, mod;
inline void add(int &x, int y)
{
x += y;
x >= mod ? x -= mod : 0;
}
inline int quick_pow(int x, int k)
{
int res = 1;
while (k)
{
if (k & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod; k >>= 1;
}
return res;
}
inline int C(int n, int m)
{
return 1ll * fra[n] * inv_fra[n - m] % mod * inv_fra[m] % mod;
}
int main()
{
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
read(n); read(m); read(mod);
ex_m[0] = ex_m1[0] = fra[0] = 1;
for (int i = 1; i <= n; ++i)
{
fra[i] = 1ll * fra[i - 1] * i % mod;
ex_m[i] = 1ll * ex_m[i - 1] * m % mod;
ex_m1[i] = 1ll * ex_m1[i - 1] * (m - 1) % mod;
}
inv_fra[n] = quick_pow(fra[n], mod - 2);
for (int i = n; i >= 1; --i)
inv_fra[i - 1] = 1ll * inv_fra[i] * i % mod;
inv[0] = 1;
for (int i = 1; i <= n; ++i)
inv[i] = 1ll * fra[i - 1] * inv_fra[i] % mod;
int ans = 0;
for (int i = 0; i < n; ++i)
for (int k = 0, km = n / (i + 1); k <= km; ++k)
{
int t = n - i * k - k,
sit1 = 1ll * k * (k > 0 ? ex_m1[k - 1] : 0) % mod * ex_m[t] % mod,
sit2 = 1ll * t * ex_m1[k] % mod * (t > 0 ? ex_m[t - 1] : 0) % mod;
int tmp = 1ll * inv[t + k] * C(t + k, k) % mod * (sit1 + sit2) % mod;
add(ans, (k & 1) ? mod - tmp : tmp);
}
ans = 1ll * (mod - m) * ans % mod;
add(ans, 1ll * n * ex_m[n] % mod);
printf("%d\n", ans);
fclose(stdin); fclose(stdout);
return 0;
}