题目
大中锋的学院要组织学生参观博物馆,要求学生们在博物馆中排成一队进行参观。他的同学可以分为四类:一部分最喜欢唱、一部分最喜欢跳、一部分最喜欢rap,还有一部分最喜欢篮球。如果队列中 k , k + 1 , k + 2 , k + 3 k,k + 1,k + 2,k + 3 k,k+1,k+2,k+3位置上的同学依次,最喜欢唱、最喜欢跳、最喜欢rap、最喜欢篮球,那么他们就会聚在一起讨论蔡徐坤。大中锋不希望这种事情发生,因为这会使得队伍显得很乱。大中锋想知道有多少种排队的方法,不会有学生聚在一起讨论蔡徐坤。两个学生队伍被认为是不同的,当且仅当两个队伍中至少有一个位置上的学生的喜好不同。由于合法的队伍可能会有很多种,种类数对998244353取模。
输入格式
输入数据只有一行。每行5个整数,第一个整数n,代表大中锋的学院要组织多少人去参观博物馆。接下来四个整数a、b、c、d,分别代表学生中最喜欢唱的人数、最喜欢跳的人数、最喜欢rap的人数和最喜欢篮球的人数。保证
a
+
b
+
c
+
d
≥
n
a+b+c+d \ge n
a+b+c+d≥n。
输出格式
每组数据输出一个整数,代表你可以安排出多少种不同的学生队伍,使得队伍中没有学生聚在一起讨论蔡徐坤。结果对
998244353
998244353
998244353取模。
输入输出样例
输入
4 4 3 2 1
输出
174
输入
996 208 221 132 442
输出
442572391
说明/提示
对于20%的数据,有
n
=
a
=
b
=
c
=
d
≤
500
n=a=b=c=d\le500
n=a=b=c=d≤500
对于100%的数据,有
n
≤
1000
n \le 1000
n≤1000 ,
a
,
b
,
c
,
d
≤
500
a, b, c, d \le 500
a,b,c,d≤500
题解
考虑枚举有
i
i
i堆人在讨论cxk小姐姐,那么被cxk迷倒的人就有
4
i
4i
4i个,且每一堆人是挨在一起的,
所以我们可以把这
4
i
4i
4i个人缩成
i
i
i个群体,每一个群体可以展开成为
4
4
4人
那么现在就只用考虑剩下的
n
−
3
i
n-3i
n−3i人,放置的方案数就是
C
n
−
3
i
i
C_{n-3i}^i
Cn−3ii
简单用整体法证明一下:
在
n
−
3
i
n-3i
n−3i人中,要选
i
i
i个拿来讨论,
n
−
4
i
n-4i
n−4i个不讨论,方案数是对应的
C
n
−
3
i
i
=
C
n
−
3
i
n
−
4
i
C_{n-3i}^i=C_{n-3i}^{n-4i}
Cn−3ii=Cn−3in−4i
也可以理解为在
n
−
3
i
n-3i
n−3i个空格里面找
i
i
i个起点开始讨论小姐姐
接着枚举好了
i
i
i及它们可能出现的地方
C
n
−
3
i
i
C_{n-3i}^i
Cn−3ii,我们就要去考虑剩下
n
−
4
i
n-4i
n−4i人的排列,可以乱来
(
n
−
4
i
)
!
(n-4i)!
(n−4i)!
但是
a
t
t
e
n
t
i
o
n
attention
attention,如果看了我的上一篇指数型生成函数专练,就会与这里有一点相通
题目说的方案数不同的要求是至少有一个位置上的学生的喜好不同,而不是学生本人不同
所以我们要除掉喜好篮球,跳舞,唱歌,rap本身内部的乱拍,因为从爱好上来看是看不出来排列不同的
设
a
,
b
,
c
,
d
a,b,c,d
a,b,c,d分别代表学生中最喜欢唱的人数、最喜欢跳的人数、最喜欢rap的人数和最喜欢篮球的总人数,我们考虑的是剩下的不讨论cxk姐姐的学生乱排,所以要剪掉外层所枚举的去参与讨论的人数
(
n
−
4
i
)
!
(
a
−
i
)
!
(
n
−
i
)
!
(
c
−
i
)
!
(
d
−
i
)
!
\frac{(n-4i)!}{(a-i)!(n-i)!(c-i)!(d-i)!}
(a−i)!(n−i)!(c−i)!(d−i)!(n−4i)!
最后写出每种喜好的生成函数,以喜欢篮球的人为例:
∑
i
=
0
c
x
i
i
!
\sum_{i=0}^{c}\frac{x^i}{i!}
i=0∑ci!xi
把四种爱好卷起来,卷出最后乘积的第
n
−
4
i
n-4i
n−4i项就是我们需要除掉的东西,乘上
(
n
−
4
i
)
!
(n-4i)!
(n−4i)!就是真正的乱排数
因为带取模,所以是用 N T T NTT NTT跑,除的话就要变成乘逆元,所以我们可以在卷的时候就直接卷逆元
但是我们又发现虽然假设的是
i
i
i堆人在讨论,但是统计答案的时候却把
≥
i
\ge i
≥i堆人讨论的情况都统计了
而且当统计
i
=
1
i=1
i=1时,会算两遍至少两堆人讨论的方案,三遍至少三堆人讨论的方案…在统计
i
=
2
i=2
i=2时,会算三遍至少三堆人讨论的方案…
当统计
i
i
i的方案时,会多算
C
j
i
C_j^i
Cji次至少
j
j
j堆人讨论的方案,所以我们可以用容斥来解决
总结一下答案应该是
∑
i
=
0
m
i
n
(
−
1
)
i
∗
C
n
−
3
i
i
∗
(
n
−
4
i
)
!
(
a
−
i
)
!
(
n
−
i
)
!
(
c
−
i
)
!
(
d
−
i
)
!
\sum_{i=0}^{min}(-1)^i*C_{n-3i}^i*\frac{(n-4i)!}{(a-i)!(n-i)!(c-i)!(d-i)!}
i=0∑min(−1)i∗Cn−3ii∗(a−i)!(n−i)!(c−i)!(d−i)!(n−4i)!
code1(NTT)
#include <cstdio>
#include <iostream>
using namespace std;
#define mod 998244353
#define LL long long
#define MAXN 10005
int n, anum, bnum, cnum, dnum;
LL pi, ni, result;
LL a[MAXN], b[MAXN], c[MAXN], d[MAXN], rev[MAXN], fac[MAXN], Invfac[MAXN];
LL qkpow ( LL x, LL y ) {
LL ans = 1;
while ( y ) {
if ( y & 1 )
ans = ans * x % mod;
x = x * x % mod;
y >>= 1;
}
return ans;
}
LL C ( int n, int m ) {
return fac[n] * Invfac[m] % mod * Invfac[n - m] % mod;
}
void NTT ( LL *c, LL limit, LL f ) {
for ( LL i = 0;i < limit;i ++ )
if ( i < rev[i] )
swap ( c[i], c[rev[i]] );
for ( LL i = 1;i < limit;i <<= 1 ) {
LL omega = qkpow ( f == 1 ? pi : ni, ( mod - 1 ) / ( i << 1 ) );
for ( LL j = 0;j < limit;j += ( i << 1 ) ) {
LL w = 1;
for ( LL k = 0;k < i;k ++, w = w * omega % mod ) {
LL x = c[k + j], y = w * c[i + j + k] % mod;
c[k + j] = ( x + y ) % mod;
c[k + j + i] = ( x - y + mod ) % mod;
}
}
}
LL inv = qkpow ( limit, mod - 2 );
if ( f == -1 )
for ( LL i = 0;i < limit;i ++ )
c[i] = c[i] * inv % mod;
}
void init () {
Invfac[0] = fac[0] = 1;
for ( int i = 1;i <= n;i ++ )
fac[i] = fac[i - 1] * i % mod;
Invfac[n] = qkpow ( fac[n], mod - 2 );
for ( int i = n - 1;i;i -- )
Invfac[i] = Invfac[i + 1] * ( i + 1 ) % mod;
}
LL solve ( int n, int A, int B, int C, int D ) {
LL len = 1, l = 0;
while ( len < ( ( A + B + C + D ) << 1 ) ) {
len <<= 1;
l ++;
}
for ( LL i = 0;i < len;i ++ )
rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );
for ( int i = 0;i < len;i ++ )
a[i] = ( i <= A ? Invfac[i] : 0 );
for ( int i = 0;i < len;i ++ )
b[i] = ( i <= B ? Invfac[i] : 0 );
for ( int i = 0;i < len;i ++ )
c[i] = ( i <= C ? Invfac[i] : 0 );
for ( int i = 0;i < len;i ++ )
d[i] = ( i <= D ? Invfac[i] : 0 );
NTT ( a, len, 1 );NTT ( b, len, 1 );NTT ( c, len, 1 );NTT ( d, len, 1 );
for ( int i = 0;i < len;i ++ )
a[i] = a[i] * b[i] % mod * c[i] % mod * d[i] % mod;
NTT ( a, len, -1 );
return a[n] * fac[n] % mod;
}
int main() {
pi = 3;
ni = mod / pi + 1;
scanf ( "%d %d %d %d %d", &n, &anum, &bnum, &cnum, &dnum );
init();
int k = min ( min ( min ( min ( anum, bnum ), cnum ), dnum ), n / 4 );
anum -= k;bnum -= k;cnum -= k;dnum -= k;
result = 0;
for ( k;k >= 0;k -- ) {
LL tmp = C ( n - 3 * k, k ) % mod * solve ( n - 4 * k, anum, bnum, cnum, dnum ) % mod;
anum ++;bnum ++;cnum ++;dnum ++;
result = ( result + ( ( k & 1 ) ? mod - tmp : tmp ) ) % mod;
}
printf ( "%lld\n", result );
return 0;
}
code2(EGF+卷积)
#include <cstdio>
#include <iostream>
using namespace std;
#define mod 998244353
#define LL long long
#define MAXN 1005
int n, a, b, c, d;
LL result;
LL foldAB[MAXN], foldCD[MAXN], fac[MAXN], Invfac[MAXN];
LL qkpow ( LL x, LL y ) {
LL ans = 1;
while ( y ) {
if ( y & 1 )
ans = ans * x % mod;
x = x * x % mod;
y >>= 1;
}
return ans;
}
LL C ( int n, int m ) {
return fac[n] * Invfac[m] % mod * Invfac[n - m] % mod;
}
void Fac () {
Invfac[0] = fac[0] = 1;
for ( int i = 1;i <= n;i ++ )
fac[i] = fac[i - 1] * i % mod;
Invfac[n] = qkpow ( fac[n], mod - 2 );
for ( int i = n - 1;i;i -- )
Invfac[i] = Invfac[i + 1] * ( i + 1 ) % mod;
}
int main() {
scanf ( "%d %d %d %d %d", &n, &a, &b, &c, &d );
Fac();
int k = min ( min ( min ( min ( a, b ), c ), d), n / 4 );
a -= k;b -= k;c -= k;d -= k;
for ( int i = 0;i <= a;i ++ )
for ( int j = 0;j <= b;j ++ )
foldAB[i + j] = ( foldAB[i + j] + Invfac[i] * Invfac[j] % mod ) % mod;
for ( int i = 0;i <= c;i ++ )
for ( int j = 0;j <= d;j ++ )
foldCD[i + j] = ( foldCD[i + j] + Invfac[i] * Invfac[j] % mod ) % mod;
result = 0;
for ( k;k >= 0;k -- ) {
LL ans = 0;
int tp = n - 4 * k;
for ( int i = 0;i <= tp;i ++ )
ans = ( ans + foldAB[i] * foldCD[tp - i] % mod ) % mod;
ans = ans * fac[tp] % mod * C ( n - k * 3, k ) % mod;
result = ( result + ( ( k & 1 ) ? mod - ans : ans ) ) % mod;
a ++;b ++;c ++;d ++;
for ( int i = 0;i <= a;i ++ )
foldAB[i + b] = ( foldAB[i + b] + Invfac[i] * Invfac[b] % mod ) % mod;
for ( int i = 0;i <= b;i ++ )
foldAB[i + a] = ( foldAB[i + a] + Invfac[i] * Invfac[a] % mod ) % mod;
foldAB[a + b] = ( foldAB[a + b] - Invfac[a] * Invfac[b] % mod + mod ) % mod;
for ( int i = 0;i <= c;i ++ )
foldCD[i + d] = ( foldCD[i + d] + Invfac[i] * Invfac[d] % mod ) % mod;
for ( int i = 0;i <= d;i ++ )
foldCD[i + c] = ( foldCD[i + c] + Invfac[i] * Invfac[c] % mod ) % mod;
foldCD[c + d] = ( foldCD[c + d] - Invfac[c] * Invfac[d] % mod ) % mod;
}
printf ( "%lld\n", result );
return 0;
}
代码或者思路有任何问题欢迎评论