题目链接
题目解法
考虑一个套路的想法:
令
f
i
f_i
fi 为钦定有
i
i
i 个不合法段,其他随便排的方案数
令
g
i
g_i
gi 为恰好有
i
i
i 个不合法段的方案数
则有
f
i
=
∑
j
=
i
.
.
.
g
j
(
j
i
)
f_i=\sum_{j=i}^{...}g_j\binom{j}{i}
fi=∑j=i...gj(ij)
二项式反演得:
g
i
=
∑
j
=
i
.
.
.
(
−
1
)
j
−
i
f
j
(
j
i
)
g_i=\sum_{j=i}^{...}(-1)^{j-i}f_j\binom{j}{i}
gi=∑j=i...(−1)j−ifj(ij)
题目只需要求
i
=
0
i=0
i=0 的情况,则
a
n
s
=
∑
i
=
0
.
.
.
(
−
1
)
i
f
i
ans=\sum_{i=0}^{...}(-1)^if_i
ans=∑i=0...(−1)ifi
考虑求出
f
i
f_i
fi
令
t
o
t
(
a
,
b
,
c
,
d
,
n
)
tot(a,b,c,d,n)
tot(a,b,c,d,n) 为需要
n
n
n 个球,每种颜色有
a
,
b
,
c
,
d
a,b,c,d
a,b,c,d 个(
a
+
b
+
c
+
d
≥
n
a+b+c+d\ge n
a+b+c+d≥n),不同的染色方案数
因为带有排列,所以用
E
G
F
EGF
EGF 暴力卷起来肯定可以做,时间复杂度为
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)
考虑更优的解法
这里有一个重要的优化
t
r
i
c
k
trick
trick:把前两个颜色分别卷起来,后两个颜色卷起来,然后把两部分再卷起来
具体来说,对于正好
a
+
b
+
c
+
d
=
n
a+b+c+d=n
a+b+c+d=n 的情况,方案数为
n
!
a
!
b
!
c
!
d
!
\frac{n!}{a!b!c!d!}
a!b!c!d!n!
考虑
g
1
i
g1_i
g1i 维护所有情况的
1
a
!
b
!
\frac{1}{a!b!}
a!b!1
g
2
i
g2_i
g2i 维护所有情况的
1
c
!
d
!
\frac{1}{c!d!}
c!d!1
然后每次添加一个
a
,
b
,
c
,
d
a,b,c,d
a,b,c,d,然后维护当前添加的贡献即可
时间复杂度 O ( n 2 ) O(n^2) O(n2)
#include <bits/stdc++.h>
using namespace std;
const int N=1100,P=998244353;
int n,a,b,c,d,f[N],g1[N],g2[N];
int fac[N],inv[N];
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
int qmi(int a,int b){
int res=1;
for(;b;b>>=1){
if(b&1) res=1ll*res*a%P;
a=1ll*a*a%P;
}
return res;
}
int C(int a,int b){ return 1ll*fac[a]*inv[b]%P*inv[a-b]%P;}
void init(){
fac[0]=1,inv[0]=1;
for(int i=1;i<N;i++) fac[i]=1ll*fac[i-1]*i%P;
inv[N-1]=qmi(fac[N-1],P-2);
for(int i=N-2;i;i--) inv[i]=1ll*inv[i+1]*(i+1)%P;
}
int main(){
init();
n=read(),a=read(),b=read(),c=read(),d=read();
int mn=min(n/4,min(min(a,b),min(c,d)));
a-=mn,b-=mn,c-=mn,d-=mn;
for(int i=0;i<=a;i++) for(int j=0;j<=b;j++) g1[i+j]=(g1[i+j]+1ll*inv[i]*inv[j])%P;
for(int i=0;i<=c;i++) for(int j=0;j<=d;j++) g2[i+j]=(g2[i+j]+1ll*inv[i]*inv[j])%P;
// cout<<mn<<'\n';
for(int i=mn;i>=0;i--){
int t=n-4*i;
for(int j=0;j<=t;j++) f[i]=(f[i]+1ll*g1[j]*g2[t-j])%P;
f[i]=1ll*C(n-3*i,i)*f[i]%P*fac[n-4*i]%P;
a++,b++;
for(int j=0;j<a;j++) g1[j+b]=(g1[j+b]+1ll*inv[j]*inv[b])%P;
for(int j=0;j<=b;j++) g1[j+a]=(g1[j+a]+1ll*inv[a]*inv[j])%P;
c++,d++;
for(int j=0;j<c;j++) g2[j+d]=(g2[j+d]+1ll*inv[j]*inv[d])%P;
for(int j=0;j<=d;j++) g2[j+c]=(g2[j+c]+1ll*inv[j]*inv[c])%P;
// cerr<<"+++";
}
int ans=0;
for(int i=0,neg=1;i<=mn;i++,neg*=-1) ans=((ans+neg*f[i])%P+P)%P;
printf("%d",ans);
return 0;
}