题目有点绕,原文大意是指:地上有 N个黄金袋,袋子重量是一个[1,n]的排列,有一个人在捡黄金,初始时他没有黄金,当他开始捡黄金时,如果当前地上的黄金比他的已有的重,他会捡起地上的丢掉已有的。
然后已知是如果他从前往后捡他会捡A次,从后往前捡他会捡B次,问有多少个排列
题意转化过来就是,从前往后遍历有A个元素是前缀最大值,从后往前遍历有B个元素是后缀最大值。问这样的排列有多少种。
首先要考虑到若排列是一个符合条件的排列,整个排列的最大值,一定会将整个排列分成两部分,即前面一部分有
a
−
1
a - 1
a−1个前缀最大值,后面一部分有
b
−
1
b - 1
b−1个后缀最大值。
这样前后两个问题是相同的,只是颠倒了。
令
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示
i
i
i 个数构成的排列里有
j
j
j 个前缀最大值的方案个数。
枚举这 i 个数里的最小值所在的位置,可以得到一个递推式:
即: d p [ i ] [ j ] = d p [ i − 1 ] [ j − 1 ] + ( i − 1 ) ∗ d p [ i − 1 ] [ j ] dp[i][j] = dp[i - 1][j - 1] + (i - 1) *dp[i - 1][j] dp[i][j]=dp[i−1][j−1]+(i−1)∗dp[i−1][j]
最小值要么放在最前面,此时另外 i − 1 i - 1 i−1 个数字构成的排列必须有 j − 1 j - 1 j−1个前缀最大值。如果最小值不放在最前面,那么它可以放在其它 i − 1 i - 1 i−1个位置中的任意一个,这时最小值没有贡献,因此另外 i − 1 i - 1 i−1个数字必须要有 j j j 个前缀最大值。
边界值为 d p [ n ] [ 0 ] = 0 , d p [ n ] [ n ] = 1 dp[n][0] = 0,dp[n][n] = 1 dp[n][0]=0,dp[n][n]=1,会发现这就是第一类斯特林数。
那么总的答案为 a n s = ∑ i = 0 n − 1 d p [ i ] [ a − 1 ] ∗ d p [ n − 1 − i ] [ b − 1 ] ∗ C ( n − 1 , i ) ans = \sum_{i = 0}^{n-1}dp[i][a - 1]*dp[n-1-i][b-1]*C(n-1,i) ans=∑i=0n−1dp[i][a−1]∗dp[n−1−i][b−1]∗C(n−1,i)
由于 d p [ i ] [ j ] dp[i][j] dp[i][j] 就是斯特林数,斯特林数的组合意义是: s ( i , j ) : s(i,j): s(i,j): i i i 个元素构成j个圆排列的方案数。
从组合意义上看待这个答案式子,它表达的是从 n − 1 n - 1 n−1 个元素中选 i i i个构成前 a − 1 a - 1 a−1 个圆排列,剩下的构成后 b − 1 b - 1 b−1 个圆排列,并且对这些方案求和。
这等价于直接用 n − 1 n - 1 n−1 个元素构成 a + b − 2 a + b - 2 a+b−2个圆排列,其中有 a − 1 a - 1 a−1个圆排列在前面, b − 1 b - 1 b−1个圆排列在后面。
那么答案式子转化为: a n s = d p [ n − 1 ] [ a + b − 2 ] ∗ C ( a + b − 2 , a − 1 ) ans = dp[n - 1][a + b - 2] * C(a + b - 2,a - 1) ans=dp[n−1][a+b−2]∗C(a+b−2,a−1)
通过递推式求出 d p [ n − 1 ] [ a + b − 2 ] dp[n - 1][a + b - 2] dp[n−1][a+b−2] 需要 n 2 n^2 n2的时间。
对第一类斯特林数
s
(
n
,
k
)
s(n,k)
s(n,k) 可以构造一个生成函数:
∏
i
=
0
n
−
1
(
x
+
i
)
\prod_{i = 0}^{n - 1}(x + i)
∏i=0n−1(x+i)
x
k
x^k
xk项的系数值就是
s
(
n
,
k
)
s(n,k)
s(n,k)
如何理解(或者说为什么是这个生成函数):用
f
n
(
x
)
f_n(x)
fn(x)表示这个生成函数
那么有
f
n
(
x
)
=
f
n
−
1
(
x
)
∗
(
x
+
n
−
1
)
=
x
f
n
−
1
(
x
)
+
(
n
−
1
)
f
n
−
1
(
x
)
f_n(x) = f_{n - 1}(x)*(x + n - 1) =xf_{n - 1}(x) + (n - 1)f_{n - 1}(x)
fn(x)=fn−1(x)∗(x+n−1)=xfn−1(x)+(n−1)fn−1(x)
这个生成函数乘的过程就是第一类斯特林数递推的过程。
多个多项式相乘可以用分治NTT,复杂度
O
(
n
log
2
n
)
O(n \log^2 n)
O(nlog2n)
这题就做完了
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 10;
const int mod = 998244353;
typedef long long ll;
vector<ll> g[maxn << 2];
ll A[maxn],B[maxn];
ll fact[maxn],ifact[maxn];
ll fpow(ll a,ll b) {
ll r = 1;
while(b) {
if(b & 1) r = r * a % mod;
a = a * a % mod;
b >>= 1;
}
return r;
}
void change(ll t[],int len) {
for(int i = 1, j = len / 2; i < len - 1; i++) {
if(i < j) swap(t[i],t[j]);
int k = len / 2;
while(j >= k) {
j -= k;
k /= 2;
}
if(j < k) j += k;
}
}
void NTT(ll t[],int len,int type) {
change(t,len);
for(int s = 2; s <= len; s <<= 1) {
ll wn = fpow(3,(mod - 1) / s);
if(type == -1) wn = fpow(wn,mod - 2);
for(int j = 0; j < len; j += s) {
ll w = 1;
for(int k = 0; k < s / 2; k++) {
ll u = t[j + k],v = t[j + k + s / 2] * w % mod;
t[j + k] = (u + v) % mod;
t[j + k + s / 2] = (u - v + mod) % mod;
w = w * wn % mod;
}
}
}
if(type == -1) {
ll inv = fpow(len,mod - 2);
for(int i = 0; i < len; i++)
t[i] = t[i] * inv % mod;
}
}
ll n,a,b;
void solve(int rt,int l,int r) {
if(l == r) {
g[rt].push_back(l);
g[rt].push_back(1);
return;
}
int mid = l + r >> 1;
solve(rt << 1,l,mid);
solve(rt << 1 | 1,mid + 1,r);
int len = 1;
while(len <= r - l + 1) len <<= 1;
for(int i = 0; i < len; i++)
A[i] = B[i] = 0;
for(int i = 0; i <= mid - l + 1; i++)
A[i] = g[rt << 1][i];
for(int i = 0; i <= r - mid; i++)
B[i] = g[rt << 1 | 1][i];
NTT(A,len,1);NTT(B,len,1);
for(int i = 0; i < len; i++)
A[i] = A[i] * B[i] % mod;
NTT(A,len,-1);
for(int i = 0; i < len; i++)
g[rt].push_back(A[i]);
g[rt << 1].clear();g[rt << 1 | 1].clear();
}
ll C(int x,int y) {
if(y > x || y < 0) return 0;
return fact[x] * ifact[y] % mod * ifact[x - y] % mod;
}
int main() {
fact[0] = 1;
for(int i = 1; i < maxn; i++)
fact[i] = fact[i - 1] * i % mod;
ifact[maxn - 1] = fpow(fact[maxn - 1],mod - 2);
for(int i = maxn - 2; i >= 0; i--)
ifact[i] = ifact[i + 1] * (i + 1) % mod;
scanf("%d%d%d",&n,&a,&b);
if(a == 0 || b == 0 || a + b - 2 > n - 1)
puts("0");
else if(n == 1) puts("1");
else {
solve(1,0,n - 2);
ll res = g[1][a + b - 2] * C(a + b - 2,a - 1) % mod;
printf("%lld\n",res);
}
return 0;
}