考虑 dp 求解,dp 构造升序序列,最后乘上 n ! n! n!
令 f [ i ] [ j ] f[i][j] f[i][j] 表示前 i i i 个位置,最大值小于等于 j j j 的贡献,转移方程: f [ i ] [ j ] = f [ i − 1 ] [ j − 1 ] ∗ j + f [ i ] [ j − 1 ] f[i][j] = f[i - 1][j - 1] * j + f[i][j - 1] f[i][j]=f[i−1][j−1]∗j+f[i][j−1]
最终答案是
f
[
n
]
[
k
]
f[n][k]
f[n][k] ,k 非常大肯定无法求解,考虑优化:
令
g
[
i
]
[
j
]
g[i][j]
g[i][j] 表示前
i
i
i 个位置,第
i
i
i 个放
j
j
j 的贡献,转移方程:
g
[
i
]
[
j
]
=
j
∗
∑
k
=
1
j
−
1
g
[
i
−
1
]
[
k
]
\displaystyle g[i][j] = j*\sum_{k = 1}^{j - 1}g[i - 1][k]
g[i][j]=j∗k=1∑j−1g[i−1][k]
显然有
f
[
n
]
[
k
]
=
∑
i
=
1
k
g
[
n
]
[
k
]
\displaystyle f[n][k] = \sum_{i = 1}^kg[n][k]
f[n][k]=i=1∑kg[n][k]
如果可以证明 g [ n ] [ k ] g[n][k] g[n][k] 是一个以 k k k 为自变量的多项式,就可以使用拉格朗日插值快速求解
使用归纳法证明:
当 n = 0 时,g[0][k] = 0,结论成立
设 n > 0 且,g[n][k] 是一个以 k 为自变量的多项式
根据转移方程,有:
g
[
n
+
1
]
[
k
]
=
k
∗
∑
i
=
1
k
−
1
g
[
n
]
[
i
]
\displaystyle g[n + 1][k] =k*\sum_{i = 1}^{k - 1}g[n][i]
g[n+1][k]=k∗i=1∑k−1g[n][i],根据k次幂和的推论,可以得知
g
[
n
+
1
]
[
k
]
g[n + 1][k]
g[n+1][k] 是以
k
k
k 为自变量的多项式,当
n
=
0
n = 0
n=0 时
d
p
[
n
]
[
k
]
dp[n][k]
dp[n][k] 是一个
0
0
0 次多项式,
n
n
n 每增一,根据转移方程可以得出 多项式的次数
+
2
+ 2
+2,因此
d
p
[
n
]
[
k
]
dp[n][k]
dp[n][k] 是一个以
k
k
k 为自变量的
2
n
2n
2n 次多项式。
因此 f [ n ] [ k ] f[n][k] f[n][k] 是一个以 k 为自变量的 2 n + 1 2n + 1 2n+1 次多项式,求出 2 n + 2 2n + 2 2n+2 个点,通过插值快速求出最终答案。不要忘了乘上 n ! n! n!
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 4e3 + 10;
typedef long long ll;
int mod,mx,n,k;
ll fac[maxn],ifac[maxn];
inline ll add(ll x, ll y) {
x += y;
if (x >= mod) x -= mod;
return x;
}
inline ll sub(ll x, ll y) {
x -= y;
if (x < 0) x += mod;
return x;
}
inline ll mul(ll x, ll y) {
return x * y % mod;
}
ll fpow(ll a,ll b) {
ll r = 1;
while(b) {
if (b & 1) r = mul(r,a);
b >>= 1;
a = mul(a,a);
}
return r;
}
ll cal(ll g[maxn],ll x) { //拉格朗日插值计算多项式
if (x <= mx) return g[x];
ll tmp = 1,inv,ans = 0;
for (int i = 1; i <= mx; i++)
tmp = mul(tmp,x - i);
for (int i = 1; i <= mx; i++) {
ll res = 1, inv = fpow(x - i,mod - 2);
res = mul(res,g[i]);
res = mul(res,ifac[i - 1]);
res = mul(res,ifac[mx - i]);
res = mul(res,inv);
res = mul(res,tmp);
if ((mx - i) & 1) res = mul(res,-1);
if (res < 0) res += mod;
ans = add(ans,res);
}
return ans;
}
ll f[maxn],tp[maxn],dp[2000][2000];
int main() {
scanf("%d%d%d",&k,&n,&mod);
fac[0] = 1;
for (int i = 1; i <= 4000; i++)
fac[i] = mul(fac[i - 1],i);
ifac[4000] = fpow(fac[4000],mod - 2);
for (int i = 4000 - 1; i >= 0; i--)
ifac[i] = mul(ifac[i + 1],i + 1);
mx = 2 * n + 4;
for (int j = 0; j <= mx; j++)
dp[0][j] = 1;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= mx; j++)
dp[i][j] = (1ll * dp[i - 1][j - 1] * j + dp[i][j - 1]) % mod;
printf("%lld\n",1ll * cal(dp[n],k) * fac[n] % mod);
return 0;
}