CF1608F MEX counting
题解
看到这个题的第一眼以为是排列?思考这不就一个 O ( n k ) O(nk) O(nk) DP 还开4秒?
原来是我傻逼了。
由于数的是序列不是排列,所以我们不能够通过当前位置的 M e x Mex Mex 值来判断前面未确定值的位置的个数。
我们可以先设 d p [ i ] [ j ] [ k ] dp[i][j][k] dp[i][j][k] 表示考虑前 i i i 个位置,当前的 M e x Mex Mex 值为 j j j,有 k k k 个位置值未确定的方案数。此时第一维大小 O ( n ) O(n) O(n),第三维大小 O ( n ) O(n) O(n)。第二维大小考虑到转移的需要,对单个 i i i 的大小可能达到 O ( n ) O(n) O(n),然而由于 M e x Mex Mex 值单增,所以我们弃掉没用状态后可以达到均摊 O ( k ) O(k) O(k)。
总状态数 O ( n 2 k ) O(n^2k) O(n2k),貌似可过,但是由于转移的时候需要确定有多少个位置取到某个值以及排列情况,所以转移并不能达到均摊 O ( 1 ) O(1) O(1)。
上面的状态定义复杂度不优的原因是,我们把太多的确定取值的工作集中在某处来完成。所以不妨在原来某个值未确定的位置多确定一些东西(有点绕啊)。
不妨设状态为 d p [ i ] [ j ] [ k ] dp[i][j][k] dp[i][j][k] 表示考虑前 i i i 个位置,当前的 M e x Mex Mex 值为 j j j,有 k k k 个未确定值的方案数。这里有一个未确定值的意思是,已经确定这个值由前面哪些位置取到,并且确定这个值在当前 M e x Mex Mex 上方,但是不确定具体位置。当然,也确定了 k k k 个未确定值的相对顺序。
此时的转移就变得方便很多。我们可以先考虑当前位置 i i i 的值确定,也就是 M e x Mex Mex 值会变动的转移。这部分的转移可以先从 i − 1 i-1 i−1 处转移过来,然后在 i i i 处按 j j j 递推,每次 M e x Mex Mex 值增加就取最靠近的未确定值补下来即可。
做完递推之后再考虑 M e x Mex Mex 值不变动的转移,这样就不会出错。
总复杂度 O ( n 2 k ) O(n^2k) O(n2k),不是很卡常,具体转移可以看代码。
代码
#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define lll __int128
#define uns unsigned
#define fi first
#define se second
#define IF (it->fi)
#define IS (it->se)
#define END putchar('\n')
#define lowbit(x) ((x)&-(x))
#define inline jzmyyds
using namespace std;
const int MAXN=2005;
const ll INF=1e17;
ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt>0)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
const ll MOD=998244353;
int n,k,b[MAXN],C[MAXN][MAXN];
ll dp[2][MAXN][MAXN],ans;
void AD(ll&a,ll b){(a+=b)>=MOD?a-=MOD:233;}
int main()
{
n=read(),k=read();
for(int i=1;i<=n;i++)b[i]=read();
C[0][0]=1;
for(int i=1;i<=n;i++){
C[i][0]=1;
for(int j=1;j<=i;j++){
C[i][j]=C[i-1][j]+C[i-1][j-1];
if(C[i][j]>=MOD)C[i][j]-=MOD;
}
}
dp[0][0][0]=1;
int l=0,r=0;
for(int id=1;id<=n;id++){
bool e=id&1,t=e^1;
int nr=min(n,b[id]+k),nl=max(l,b[id]-k);
if(nl>nr)return print(0),0;
for(int i=l;i<=nr;i++)
for(int j=0;j<=id;j++)dp[e][i][j]=0; //初始化
for(int i=l;i<=r;i++)
for(int j=0;j<id;j++)
AD(dp[e][i+1][j],dp[t][i][j]); //第一遍转移
for(int i=l;i<nr;i++)
for(int j=1;j<=id;j++)
AD(dp[e][i+1][j-1],dp[e][i][j]); //递推
for(int i=l;i<=r;i++)
for(int j=0;j<id;j++)if(const ll d=dp[t][i][j]){ //第二遍转移
(dp[e][i][j]+=d*(i+j))%=MOD;
(dp[e][i][j+1]+=d*(j+1))%=MOD;
}
l=nl,r=nr;
}
for(int i=l;i<=r;i++)for(int j=0;j<=n-i;j++)
(ans+=dp[n&1][i][j]*C[n-i][j])%=MOD;
print(ans);
return 0;
}