题意:
有四类人排队,每类人分别喜欢唱、跳、rap、篮球,分别有 a,b,c,d 个人,队伍长度为 n。如果存在某个 k,使得 k,k+1,k+2,k+3 四个位置上的人依次喜欢唱、跳、rap、篮球,则该排列不合法,求合法的排列方案数(对 998244353 取模)。n,a,b,c,d<=1e3
思路:
容斥:对 i=1,2,…,⌊n/4⌋,分别求至少包含 i 个不合法四人组的方案数;其中剩余位置的排列数用指数生成函数(NTT 卷积)计算即可。
#include <bits/stdc++.h>
using namespace std;
// Capacity of the coefficient, permutation and root-table buffers declared below.
const int N = 4040;
// a, b : polynomial coefficient buffers; calc() multiplies a by b and leaves the product in a.
// nn   : transform length; main doubles it from 1 until it exceeds 2*n.
// rev  : index permutation applied at the start of NTT() (filled in main).
// w1/w2: root tables used by NTT() for direction I == 1 / any other I (filled in calc()).
long long a[N],b[N],nn = 1,rev[N],w1[N],w2[N];
// Prime modulus used for every arithmetic operation in this file.
const int mod = 998244353;

/**
 * Modular exponentiation by repeated squaring.
 *
 * @param di  base
 * @param ci  exponent; callers in this file pass non-negative values
 * @return    di raised to the power ci, reduced modulo `mod`
 */
inline int power(int di,int ci) {
    long long result = 1;
    long long base = di;
    // Walk the exponent bit by bit, lowest bit first.
    for (int exponent = ci; exponent != 0; exponent >>= 1) {
        if (exponent & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}
/**
 * Modular inverse of x modulo `mod`, computed as x^(mod-2) via power().
 *
 * @param x  value to invert
 * @return   power(x, mod-2), widened to long long
 */
inline long long inv(int x) {
    const int fermat_exponent = mod - 2;
    const int inverse = power(x, fermat_exponent);
    return static_cast<long long>(inverse);
}
/**
 * In-place number-theoretic transform of x[0..nn) modulo `mod`.
 *
 * Relies on the file-level state: `nn` (transform length), `rev` (index
 * permutation filled in main) and the root tables `w1` / `w2` (filled in
 * calc()).
 *
 * @param x  buffer holding at least nn coefficients; overwritten with the result
 * @param I  direction selector: 1 uses the w1 table; any other value uses w2.
 *           When I == -1 the result is additionally multiplied by the
 *           modular inverse of nn.
 */
inline void NTT(long long *x,int I) {
    // Apply the rev permutation; the rev[i] > i test swaps each pair only once.
    for (int i = 0; i < nn; i++)
        if (rev[i] > i)
            std::swap(x[rev[i]], x[i]);

    const long long *w = (I == 1 ? w1 : w2);

    // Butterfly passes: `len` is the half-size of the groups being merged,
    // and w[len|k] is the k-th root for that level.
    for (int len = 1; len < nn; len <<= 1) {
        for (int start = 0; start < nn; start += (len << 1)) {
            for (int k = 0; k < len; k++) {
                const long long lo = x[start|k];
                const long long hi = w[len|k] * x[len|start|k] % mod;
                x[start|k] = (lo + hi) % mod;
                x[len|start|k] = ((lo - hi) % mod + mod) % mod;
            }
        }
    }

    if (I == -1) {
        // The inverse of nn does not depend on the element, so compute it once
        // rather than running a modular exponentiation for every entry.
        const long long inv_n = inv(nn);
        for (int i = 0; i < nn; i++)
            x[i] = x[i] * inv_n % mod;
    }
}
// Half of the transform length nn (assigned in main right after nn is chosen).
int half;
// Values read from standard input in main, in the order n, aa, bb, cc, dd:
// n is the queue length, aa..dd are the head counts of the four groups.
int aa,bb,cc,dd,n;
/**
 * Multiplies polynomial `a` by polynomial `b` modulo `mod` using NTT and
 * stores the product, truncated to degree n, back into `a`.
 *
 * Side effects: rebuilds the root tables w1 / w2 for the current `nn`,
 * leaves `b` in its transformed state, and zeroes a[n+1..nn].
 *
 * Assumes nn is a power of two with half == nn/2, as set up in main.
 */
void calc() {
    // Primitive nn-th root of unity (3 is the generator used for this modulus)
    // and its inverse. The table entries are their successive powers, so they
    // are built with one multiplication each instead of a full modular
    // exponentiation / inversion per entry.
    const long long root = power(3, static_cast<int>((mod - 1) / nn));
    const long long root_inv = inv(static_cast<int>(root));

    long long forward = 1;
    long long backward = 1;
    for (int i = 0; i < half; i++) {
        w1[i|half] = forward;    // root^i
        w2[i|half] = backward;   // root^(-i), i.e. the inverse of w1[i|half]
        forward = forward * root % mod;
        backward = backward * root_inv % mod;
    }
    // Lower levels reuse every second root of the level above.
    for (int i = half - 1; i > 0; --i) {
        w1[i] = w1[i<<1];
        w2[i] = w2[i<<1];
    }

    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < nn; i++)
        a[i] = b[i] * a[i] % mod;
    NTT(a, -1);

    // Keep only the coefficients up to x^n.
    for (int i = n + 1; i <= nn; i++)
        a[i] = 0;
}
// njc[i] is the modular inverse of i! (mod `mod`); main fills entries 0..n.
long long njc[1010];
inline void work(int p) {
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
for (int i = 0;i <= min(aa-p,n); i++)
a[i] = njc[i];
for (int i = 0;i <= min(bb-p,n); i++)
b[i] = njc[i];
calc();
memset(b,0,sizeof(b));
for (int i = 0;i <= min(cc-p,n); i++)
b[i] = njc[i];
calc();
memset(b,0,sizeof(b));
for (int i = 0;i <= min(dd-p,n); i++)
b[i] = njc[i];
calc();
}
// C[i][j]: binomial coefficient "i choose j" modulo `mod`, built as Pascal's triangle in main.
long long C[1010][1010];
// f[i]: per-i counts filled in main's loop over i and then rewritten in place
// by the subtraction loop that follows; f[0] is the value printed.
long long f[1010];
/**
 * Reads n, aa, bb, cc, dd from standard input, counts the valid queues
 * modulo `mod` by inclusion-exclusion, and prints the result.
 *
 * @return 0 on success; 1 if the input is malformed or n lies outside the
 *         range the fixed-size tables can handle.
 */
int main() {
    if (scanf("%d%d%d%d%d",&n,&aa,&bb,&cc,&dd) != 5)
        return 1;
    // njc, C and f have 1010 entries and the transform buffers are sized for
    // n <= 1000; any other n would index out of bounds below.
    if (n < 0 || n > 1000)
        return 1;

    // Pascal's triangle modulo mod, rows 0..1000.
    for (int i = 0; i <= 1000; i++) {
        C[i][i] = C[i][0] = 1;
        for (int j = 1; j < i; j++)
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
    }

    // Transform length: smallest power of two strictly greater than 2n, so the
    // product of two degree-n polynomials fits without wrap-around.
    while (nn <= n + n)
        nn <<= 1;
    half = nn / 2;
    for (int i = 1; i < nn; i++)
        rev[i] = (rev[i>>1] >> 1) | ((i & 1) ? half : 0);

    // fact[i] = i! and njc[i] = 1/i! modulo mod. Keeping the factorials avoids
    // a modular exponentiation per iteration to turn njc back into i!.
    long long fact[1010];
    fact[0] = 1;
    njc[0] = 1;
    for (int i = 1; i <= n; i++) {
        fact[i] = fact[i-1] * i % mod;
        njc[i] = njc[i-1] * inv(i) % mod;
    }

    // f[i]: C[n-3i][i] ways to place i forbidden quadruples, times the number
    // of ways to arrange the remaining n-4i people ((n-4i)! times the series
    // coefficient produced by work(i)).
    const int max_groups = n / 4;
    for (int i = 0; i <= max_groups; i++) {
        if (i > aa || i > bb || i > cc || i > dd)
            break;
        work(i);
        const int rest = n - 4 * i;
        f[i] = a[rest] * fact[rest] % mod * C[n-3*i][i] % mod;
    }

    // Binomial inversion from the top down: after this loop f[i] has had every
    // higher entry's C[j][i]-weighted contribution removed.
    for (int i = max_groups; i >= 0; i--) {
        for (int j = i + 1; j <= max_groups; j++)
            (f[i] -= f[j] * C[j][i]) %= mod;
    }

    // The subtraction above may leave a negative residue; normalise to [0, mod).
    f[0] = (f[0] + mod) % mod;
    printf("%lld",f[0]);
    return 0;
}