传送门
题解:
在洛谷看到一群 O ( n 3 ) O(n^3) O(n3)AC的在嘲讽 O ( n 2 log n ) O(n^2\log n) O(n2logn)的。。。
但是实际上这道题有严格 O ( n 2 ) O(n^2) O(n2)的解法。。。
首先考虑容斥,计算出队伍里面至少有 i i i组同学在讨论 [数据删除] 的方案数进行容斥。
发现是个卷积,于是非常愉快可以
O
(
n
2
log
n
)
O(n^2\log n)
O(n2logn)上一个NTT。
但是跑得没有
O
(
n
3
)
O(n^3)
O(n3)快非常尴尬。
那么我们的想法其实很显然,考虑在前面已经确定了 j j j个位置的同学来讨论 [数据删除],剩下的位置全部未定,那么显然方案数就是四个EGF的乘积展开后总次数为 m m m的项之和,其中 m m m是未定位置数量。
我们可以看做是 ( x + y + z + w ) m (x+y+z+w)^m (x+y+z+w)m,其中 x x x的次数不超过 a a a, y y y的次数不超过 b b b, z z z的次数不超过 c c c, w w w的次数不超过 d d d的,所有玩意的系数之和。
考虑二项式展开两次,发现是一个组合数的区间和,直接维护即可。
那么确定有 j j j个位置的同学在讨论 [数据删除] 的方案数可以直接来一个背包DP,设 f [ i ] [ j ] [ 0 / 1 / 2 / 3 ] f[i][j][0/1/2/3] f[i][j][0/1/2/3]表示前面 i i i个位置已经确定了 j j j组同学讨论 [数据删除],最后一组的开头在 j , j − 1 , j − 2 , ≤ j − 3 j,j-1,j-2,\leq j-3 j,j−1,j−2,≤j−3的情况的方案数。
于是 O ( n 2 ) O(n^2) O(n2)此题得到解决。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const
using std::cerr;
using std::cout;
cs int mod=998244353;
inline int add(int a,int b){a+=b-mod;return a+(a>>31&mod);}
inline int dec(int a,int b){a-=b;return a+(a>>31&mod);}
inline int mul(int a,int b){ll r=(ll)a*b;return r>=mod?r%mod:r;}
inline void Inc(int &a,int b){a+=b-mod;a+=a>>31&mod;}
inline void Dec(int &a,int b){a-=b;a+=a>>31&mod;}
inline void Mul(int &a,int b){a=mul(a,b);}
cs int N=1e3+7;
int A,B,C,D,n,mx;
int c[N][N],s[N][N];
inline void init(){
for(int re i=0;i<=n;++i){
c[i][0]=s[i][0]=1;
for(int re j=1;j<=i;++j)
c[i][j]=add(c[i-1][j],c[i-1][j-1]),
s[i][j]=add(s[i][j-1],c[i][j]);
}
}
inline int calc(int n,int l,int r){
return dec(s[n][r],(l?s[n][l-1]:0));
}
int f[N][255][4];
signed main(){
#ifdef zxyoi
freopen("queue.in","r",stdin);
#endif
scanf("%d%d%d%d%d",&n,&A,&B,&C,&D);init();
mx=std::min(n>>2,std::min(std::min(A,B),std::min(C,D)));
f[0][0][3]=1;
for(int re i=1;i<=n;++i)
for(int re j=0;j<=mx;++j){
Inc(f[i][j][1],f[i-1][j][0]);
Inc(f[i][j][2],f[i-1][j][1]);
Inc(f[i][j][3],f[i-1][j][2]);
Inc(f[i][j][3],f[i-1][j][3]);
Inc(f[i][j+1][0],f[i-1][j][3]);
}int ans=0;
for(int re i=0;i<=mx;++i){
int other=0,sum=0;
other=n<4?1:add(add(f[n-3][i][0],f[n-3][i][1]),add(f[n-3][i][2],f[n-3][i][3]));
int m=n-(i<<2);
int A=::A-i,B=::B-i,C=::C-i,D=::D-i;
for(int re j=0;j<=m;++j){
int l1=std::max(0,j-A);
int r1=std::min(B,j);
int l2=std::max(0,m-j-C);
int r2=std::min(D,m-j);
if(l1>r1||l2>r2)continue;
Inc(sum,mul(c[m][j],mul(calc(j,l1,r1),calc(m-j,l2,r2))));
}
(i&1)?Dec(ans,mul(other,sum)):Inc(ans,mul(other,sum));
}
cout<<ans<<"\n";
return 0;
}