题目大意:有n个数字一开始全0,每次随机一个数字++,问期望多少步后第一次有个位置的数字的值是k。
n
≤
50
,
k
≤
1000
n\le50,k\le1000
n≤50,k≤1000。
题解:
显然k=1是个min-max容斥。因此min-max容斥。假设枚举的集合大小是
a
a
a。
一开始场上的做法是,期望转为
∑
i
≥
1
P
(
a
n
s
≥
i
)
\sum_{i\ge1}P(ans\ge i)
∑i≥1P(ans≥i),然后
P
(
a
n
s
≥
i
)
P(ans\ge i)
P(ans≥i)就是说到
i
−
1
i-1
i−1的时候还没好的概率,然后再枚举这i-1次操作中多少次落在左半边,整理式子后发现有个组合数数列点积等比数列的求和,总之推导一波后发现要计算大小为a的集合i步后不存在>=k的数值的方案数,发现这玩意只能
n
2
k
2
n^2k^2
n2k2dp,但是是卷积,场上写个NTT过了。
另一种做法是直接枚举:
∑
i
≥
1
P
(
a
n
s
=
i
)
i
\sum_{i\ge 1}P(ans=i)i
∑i≥1P(ans=i)i,然后还是枚举左边的次数j,右边还是组合数数列点积等比数列的求和(其实算出来就是
n
a
\frac na
an),推导一下发现要算大小为a的集合i步后恰好有个位置是k的方案数,这个就用钦定某个数第i步是出现恰好k次的总方案数减去不合法的位置。后者用个前缀和即可。其实就是推式子很麻烦,但直接冷静思考一波就是a的答案乘以
n
a
\frac na
an的比率即可。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define p 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const int N=52,K=1002;
int f[2][N*K],fac[N*K],facinv[N*K],inv[N*K],mi1[N*K],mi2[N*K];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%p) (k&1)?ans=(lint)ans*x%p:0;return ans; }
inline int C(int n,int m) { return assert(n>=m&&m>=0),(lint)fac[n]*facinv[m]%p*facinv[n-m]%p; }
inline int prelude(int n)
{
rep(i,fac[0]=1,n) fac[i]=(lint)i*fac[i-1]%p;
facinv[n]=fast_pow(fac[n],p-2);
for(int i=n-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%p;
rep(i,1,n) inv[i]=(lint)fac[i-1]*facinv[i]%p;
return 0;
}
int main()
{
int n=inn(),k=inn(),ans=0;prelude(n*k);
rep(i,1,n)
{
int *now=f[i&1],*pre=f[(i-1)&1],s=0;
rep(j,k,(i-1)*(k-1)+1) pre[j]=(pre[j]+(i-1ll)*pre[j-1])%p;
rep(j,mi1[0]=1,i*(k-1)+1) mi1[j]=mi1[j-1]*(i-1ll)%p;
rep(j,mi2[0]=1,i*(k-1)+2) mi2[j]=(lint)mi2[j-1]*inv[i]%p;
rep(j,k,i*(k-1)+1)
now[j]=(lint)i*C(j-1,k-1)%p*(mi1[j-k]-pre[j-k])%p,
(now[j]<0?now[j]+=p:0),s=(s+(lint)now[j]*j%p*mi2[j+1])%p;
if(i&1) ans=(ans+(lint)C(n,i)*s)%p;
else ans=(ans-(lint)C(n,i)*s)%p,(ans<0?ans+=p:0);
}
return !printf("%lld\n",(lint)ans*n%p);
}
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define p 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const int N=52,K=1005,QWQ=67000;
int dp[N][QWQ],fac[N*K],facinv[N*K],mi[N*K],tmp[QWQ];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%p) (k&1)?ans=(lint)ans*x%p:0;return ans; }
inline int sol(int x,int s) { return (s&1)?(x?p-x:0):x; }
inline int C(int n,int m) { return (lint)fac[n]*facinv[m]%p*facinv[n-m]%p; }
namespace NTT_space{
const int N=67000;
int r[N],*dwg[N],*dwgi[N];
inline int prelude_dwg()
{
int n=N-2;
for(int i=2,t=1;i<=n;i<<=1,t++)
{
dwg[t]=new int[i],dwgi[t]=new int[i];
int *d=dwg[t],*di=dwgi[t];d[0]=di[0]=1;
int w=fast_pow(3,(p-1)/i),wi=fast_pow(3,p-1-(p-1)/i);
rep(j,1,i-1) d[j]=(lint)d[j-1]*w%p,di[j]=(lint)di[j-1]*wi%p;
}
return 0;
}
inline int pre(int m)
{
int n=1,L=0;
while(n<=m) n<<=1,L++;
rep(i,1,n-1) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
return n;
}
inline int NTT(int *a,int n,int s)
{
rep(i,1,n-1) if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=2,c=1;i<=n;i<<=1,c++)
{
int *d=(s>0?dwg[c]:dwgi[c]);
for(int j=0,t=i>>1,x,y;j<n;j+=i) rep(k,0,t-1)
x=a[j+k],y=(lint)d[k]*a[j+k+t]%p,
a[j+k]=x+y,(a[j+k]>=p?a[j+k]-=p:0),
a[j+k+t]=x-y,(a[j+k+t]<0?a[j+k+t]+=p:0);
}
if(s<0) for(int i=0,v=fast_pow(n,p-2);i<n;i++) a[i]=(lint)a[i]*v%p;
return 0;
}
}using NTT_space::NTT;
inline int prelude(int n,int k)
{
dp[0][0]=1;int m=max(n,n*(k-1));
rep(i,fac[0]=1,m) fac[i]=(lint)fac[i-1]*i%p;
facinv[m]=fast_pow(fac[m],p-2);
for(int i=m-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%p;
int t=NTT_space::pre(n*(k-1));
NTT(dp[0],t,1);rep(i,0,k-1) tmp[i]=facinv[i];NTT(tmp,t,1);
rep(i,1,n) rep(j,0,t-1) dp[i][j]=(lint)dp[i-1][j]*tmp[j]%p;
rep(i,1,n) NTT(dp[i],t,-1);
rep(i,1,n) rep(j,0,i*(k-1)) dp[i][j]=(lint)dp[i][j]*fac[j]%p;
return 0;
}
int main()
{
NTT_space::prelude_dwg();
int n=inn(),k=inn(),ans=0;prelude(n,k);
rep(i,1,n)
{
int s=0,v=fast_pow(i,p-2);
rep(j,mi[0]=1,i*(k-1)+1)
mi[j]=(lint)mi[j-1]*v%p;
rep(j,0,i*(k-1))
s=(s+(lint)mi[j+1]*dp[i][j])%p;
ans+=sol((lint)C(n,i)*s%p,i+1);
if(ans>=p) ans-=p;
}
return !printf("%lld\n",(lint)ans*n%p);
}