Solution \text{Solution} Solution
神奇题目。
首先可以强制所有的数递增,最后的答案乘一个
n
!
n!
n! 即可。
设
d
p
i
,
j
dp_{i,j}
dpi,j 表示在
[
1
,
j
]
[1,j]
[1,j] 的值域选了
i
i
i 个数的答案,不难写出 dp 转移:
d
p
i
,
j
=
d
p
i
−
1
,
j
−
1
×
j
+
d
p
i
,
j
−
1
dp_{i,j}=dp_{i-1,j-1}\times j+dp_{i,j-1}
dpi,j=dpi−1,j−1×j+dpi,j−1
答案就是
d
p
n
,
k
dp_{n,k}
dpn,k。
直接暴力做是
O
(
n
k
)
O(nk)
O(nk) 的,无法通过。
考虑使用拉格朗日插值优化。
既然要用拉格朗日插值,关键就在与证明
d
p
n
,
k
dp_{n,k}
dpn,k 是一个以
k
k
k 为自变量的
f
n
f_n
fn 次多项式。
首先又一个较为显然的结论,若
g
(
x
)
g(x)
g(x) 是一个
k
k
k 次多项式,那么它的差分
g
(
x
)
−
g
(
x
−
1
)
g(x)-g(x-1)
g(x)−g(x−1) 就是一个
k
−
1
k-1
k−1 次多项式。
那么回到刚才的转移式,它也可以写成:
d
p
i
,
j
−
d
p
i
,
j
−
1
=
d
p
i
−
1
,
j
−
1
×
j
dp_{i,j}-dp_{i,j-1}=dp_{i-1,j-1}\times j
dpi,j−dpi,j−1=dpi−1,j−1×j
考虑多项式次数,也就是:
f
n
−
1
=
f
n
−
1
+
1
f_n-1=f_{n-1}+1
fn−1=fn−1+1
也就是说
f
n
f_n
fn 是一个公差为二的等差数列。
又因为有:
d
p
n
,
0
=
0
,
f
0
=
0
dp_{n,0}=0,f_0=0
dpn,0=0,f0=0,所以就能得到:
f
n
=
2
n
f_n=2n
fn=2n
O
(
n
2
)
O(n^2)
O(n2) 暴力求出前
n
n
n 项插值即可,连续值域插值可以前缀和优化到线性。
总复杂度
O
(
n
2
)
O(n^2)
O(n2)。
Code \text{Code} Code
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
inline ll read(){
ll x(0),f(1);char c=getchar();
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
const int N=2050;
int mod;
ll n,m;
inline ll ksm(ll x,ll k){
ll res(1);
while(k){
if(k&1) res=x*res%mod;
x=x*x%mod;
k>>=1;
}
return res;
}
ll x[N],y[N];
ll jc[N],suf[N],pre[N],ni[N];
ll lagrange(int n,ll *y,ll k){//consecutive
k%=mod;
jc[0]=1;
for(int i=1;i<=n;i++) jc[i]=jc[i-1]*i%mod;
ni[n]=ksm(jc[n],mod-2);
for(int i=n-1;i>=0;i--) ni[i]=ni[i+1]*(i+1)%mod;
pre[0]=1;
for(int i=1;i<=n;i++) pre[i]=pre[i-1]*(k-i)%mod;
suf[n+1]=1;
for(int i=n;i>=1;i--) suf[i]=suf[i+1]*(k-i)%mod;
ll res(0);
for(int i=1;i<=n;i++){
ll add=y[i]*pre[i-1]%mod*suf[i+1]%mod*ni[i-1]%mod*ni[n-i]%mod;
if((n-i)&1) add=mod-add;
(res+=add)%=mod;
}
return res;
}
ll dp[505][1505];
signed main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
m=read();n=read();mod=read();
for(int i=0;i<=2*n+1;i++) dp[0][i]=1;
for(int i=1;i<=n;i++){
for(int j=1;j<=n*2+1;j++){
dp[i][j]=(dp[i][j-1]+dp[i-1][j-1]*j)%mod;
}
}
for(int i=1;i<=2*n+1;i++){
y[i]=dp[n][i];
}
ll res=lagrange(2*n+1,y,m);
printf("%lld\n",res*jc[n]%mod);
return 0;
}
/*
*/