题目链接:火柴排队
题意
给你长度为n的序列a,随机使得k个元素增加d,求增加后的序列 a i ′ {a_i'} ai′满足:如果 a i < a j , 那 么 a i ′ < a j ′ {a_i<a_j,那么a_i'<a_j'} ai<aj,那么ai′<aj′的概率为多少。对于每个1≤K≤n,你要输出每个答案并对998244353取模。
题解
本题发现n≤5000,并且每个状态都和前一个状态有关系,所以我们可以尝试用dp求解。
分析本题,对序列的顺序并无要求,所以可以先对序列排序,这样对一个元素加d后,只需判断和后面的关系即可。
-
定义:dp[i][j][k]:前i个数已经选择j个,并且当前这个数是否选择,满足题意的情况有多少种。
(k为0代表当前这个数未选,1代表当前这个数已经被选中。) -
初始化:dp[1][0][0]=dp[1][1][1]=1;
-
状态转移方程:
(如果未选a[i],那么前i-1个数必定已经选了j个数,如果a[i-1]已经被选择,那么需要判断(a[i-1]+d)后与a[i]的关系再选择是否加上该值。如果未选中,直接加上即可。)
d p [ i ] [ j ] [ 0 ] = d p [ i − 1 ] [ j ] [ 0 ] + d p [ i − 1 ] [ j ] [ 1 ] ∗ ( a [ i − 1 ] + d ≤ a [ i ] ) {dp[i][j][0]=dp[i-1][j][0]+dp[i-1][j][1]*(a[i-1]+d≤a[i])} dp[i][j][0]=dp[i−1][j][0]+dp[i−1][j][1]∗(a[i−1]+d≤a[i])
(如果已选a[i],那么前i-1个数必定只选了(j-1)个数。如果a[i-1]已被选择,由于a[i]也被选,所以同时+d,不会改变顺序。所以两个状态都加上)
d p [ i ] [ j ] [ 1 ] = d p [ i − 1 ] [ j − 1 ] [ 0 ] + d p [ i − 1 ] [ j − 1 ] [ 1 ] {dp[i][j][1]=dp[i-1][j-1][0]+dp[i-1][j-1][1]} dp[i][j][1]=dp[i−1][j−1][0]+dp[i−1][j−1][1]
由于n≤5000,dp[5000][5000][2],虽然理论上不会超,但是保险期间还是用滚动数组优化,状态只在i和(i-1)之间转化,这样数组大小可以开为dp[2][5000][2]。
这样我们就可以在 O ( n 2 ) {O(n^2)} O(n2)的时间复杂度内知道,选择i个元素增加d后符合题意的情况个数:dp[n][i][0]+dp[n][i][1]
总的情况种类数排列组合也能得出: C n i {C_n^i} Cni
求排列组合可通过递推求解,
C
n
k
=
n
−
k
+
1
k
C
n
k
−
1
{C_n^k=\frac{n-k+1}{k}C_n^{k-1}}
Cnk=kn−k+1Cnk−1
所以我们也可以在O(n)的时间复杂度内求出
C
n
0
,
C
n
1
,
.
.
.
.
.
C
n
n
{C_n^0,C_n^1,.....C_n^n}
Cn0,Cn1,.....Cnn
这样易得出概率为: d p [ n ] [ i ] [ 0 ] + d p [ n ] [ i ] [ 1 ] C n i {\frac{dp[n][i][0]+dp[n][i][1]}{C_n^i}} Cnidp[n][i][0]+dp[n][i][1]。
由于取模求逆元,模值为质数,可用费马小定理求得逆元inv(a)=pow(a,mod-2)。
代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<bitset>
#include<cassert>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<deque>
#include<iomanip>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
using namespace std;
//extern "C"{void *__dso_handle=0;}
typedef long long ll;
typedef long double ld;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define lowbit(x) x&-x
const double PI=acos(-1.0);
const double eps=1e-6;
const ll mod=998244353;
const int inf=0x3f3f3f3f;
const int maxn=5e3+10;
const int maxm=100+10;
#define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
ll a[maxn],c[maxn];
ll qpow(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return (ans+mod)%mod;
}
ll inv(ll a) { return qpow(a, mod-2); }
void getcc(ll n)
{
c[0]=1;
for(int i=1;i<=n;i++) c[i]=(n-i+1)*c[i-1]%mod*inv(i)%mod;
}
ll dp[2][maxn][2];
int main()
{
ll n,d;
scanf("%lld%lld",&n,&d);
getcc(n);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
sort(a+1,a+1+n);
dp[1][1][1]=dp[1][0][0]=1;
for(int i=2;i<=n;i++)
for(int j=0;j<=i;j++)
{
dp[i&1][j][0]=(dp[(i-1)&1][j][0]+dp[(i-1)&1][j][1]*((a[i-1]+d)<=a[i]))%mod;
if(j-1>=0) dp[i&1][j][1]=(dp[(i-1)&1][j-1][1]+dp[(i-1)&1][j-1][0])%mod;
}
for(int i=1;i<=n;i++)
{
ll ans=(((dp[n&1][i][0]+dp[n&1][i][1])%mod)*qpow(c[i], mod-2)%mod+mod)%mod;
printf("%lld\n",ans);
}
}