有n个数字,每单位时间会出现一个数字,第i个数字有
p
i
m
\frac{p_i}{m}
mpi的概率出现,并且
∑
i
=
1
n
p
i
=
m
\sum_{i=1}^n p_i=m
∑i=1npi=m,求出现了k个数字的时间的期望。
n
≤
1000
,
m
≤
10000
,
n
−
k
≤
10
n\le1000, m\le 10000,n-k\le10
n≤1000,m≤10000,n−k≤10
这个玄学数据范围可海星,可以做到O(nm(n-k))或者O(nmk)。
首先考虑这个问题,询问的k,等价于询问集合的第n-k+1大。
现在考虑转化后的问题,也就是要找到一个函数F,使得下式成立:
∑
T
⊆
S
m
i
n
(
T
)
F
(
∣
T
∣
)
=
k
\sum_{T\subseteq S}min(T)F(|T|)=k
∑T⊆Smin(T)F(∣T∣)=k,这里的k已经是刚刚的n-k+1了。
min(T)在这个题目里面就是m/Sum(T)。
考虑一般Min-Max容斥的证明,对于第x大的元素,其被计算系数应该为[x==n-k+1]:
[
x
=
n
−
k
+
1
]
=
∑
T
=
0
n
−
x
(
n
−
x
T
)
F
(
∣
T
∣
+
1
)
[x=n-k+1]=\sum_{T=0}^{n-x}\binom{n-x}{T}F(|T|+1)
[x=n−k+1]=∑T=0n−x(Tn−x)F(∣T∣+1)
我们希望构造一个函数使得这个式子成立,也就是当x< k的时候不会被计算,当x>k的时候会被容斥掉,那么我们取:
F
(
x
)
=
(
−
1
)
x
−
k
(
x
−
1
k
−
1
)
F(x)=(-1)^{x-k}\binom{x-1}{k-1}
F(x)=(−1)x−k(k−1x−1)
带回原式:
[
x
=
k
]
=
∑
T
=
0
n
−
x
(
n
−
x
T
)
(
−
1
)
T
+
1
−
k
(
T
k
−
1
)
[x=k]=\sum_{T=0}^{n-x}\binom{n-x}{T}(-1)^{T+1-k}\binom{T}{k-1}
[x=k]=∑T=0n−x(Tn−x)(−1)T+1−k(k−1T)
这样,当n-x< k-1时,T< k-1,后面那个组合数直接是0;否则,当n-x>=k的时候,相当于是先从n-x个元素中选出T的,然后从其中选出k-1个;换句话说就是先从n-x个中选出k-1个,然后剩下的n-x-k+1个元素再任意选择,并且根据这些元素的个数决定系数的正负,显然后半部分只有当n-x-k+1=0的时候系数是1,此时两个组合数也是1,这样就得证了。
我们带回最最开始的式子:
∑
T
⊆
S
m
S
u
m
(
T
)
(
−
1
)
∣
T
∣
−
k
(
∣
T
∣
−
1
k
−
1
)
\sum_{T\subseteq S}\frac{m}{Sum(T)}(-1)^{|T|-k}\binom{|T|-1}{k-1}
∑T⊆SSum(T)m(−1)∣T∣−k(k−1∣T∣−1)
注意到k比较小,我们设dp(i, j, t)表示前i个数字,选出来的数字之和是t,前面的系数之和在k=j的时候是多少,转移显然。
如果k比较大,那么因为|T|>=k所以直接设dp(i, j, t)表示前i个数字删去j个,剩下的数字之和是多少。
总之复杂度如上文。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<assert.h>
#define N 1010
#define M 10010
#define NMK 15
#define mod 998244353
#define lint long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
int dp[2][NMK][M],p[N],fac[M],facinv[M],inv[M];
inline int sol(lint x,int s) { return x%=mod,((s&1)?(mod-x)%mod:x); }
inline int fast_pow(int x,int k,int ans=1)
{ for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans; }
inline int prelude(int n)
{
for(int i=fac[0]=1;i<=n;i++) fac[i]=(lint)fac[i-1]*i%mod;
facinv[n]=fast_pow(fac[n],mod-2);
for(int i=n-1;i>=0;i--) facinv[i]=facinv[i+1]*(i+1ll)%mod;
for(int i=1;i<=n;i++) inv[i]=(lint)fac[i-1]*facinv[i]%mod;
return 0;
}
int main()
{
int n,k,m,ans=0,now,pre;scanf("%d%d%d",&n,&k,&m);
rep(i,1,n) scanf("%d",&p[i]);prelude(max(n,m));
pre=0,now=1,k=n-k+1;rep(i,1,k) dp[pre][i][0]=-1;
for(int i=1;i<=n;i++,swap(now,pre))
{
rep(j,1,k) memcpy(dp[now][j],dp[pre][j],sizeof(int)*(m+1));
rep(j,1,k) rep(t,p[i],m)
(dp[now][j][t]+=mod-dp[pre][j][t-p[i]])%=mod,
(dp[now][j][t]+=dp[pre][j-1][t-p[i]])%=mod;
}
swap(now,pre);
rep(i,1,m) if(dp[now][k][i]) (ans+=(lint)dp[now][k][i]*inv[i]%mod)%=mod;
return !printf("%lld\n",(lint)ans*m%mod);
}