标签:组合数学
思路
这个式子指的是n个数中选k个数相乘,所有组合的结果相加
而注意到本题a的范围只有三个数,这就大大降低了题目难度,如果某个组合中有0,那么这个组合对最终的答案贡献为0,所以我们可以直接不管为0的项,在其余的项中选择即可
我们假设为1的项有t1个,为2的项有t2个,如果选择i个1,方案数为C(t1,i)那么就需要选择k-i个2,方案数为C(t2,k-i),为1的项不会对组合的贡献产生影响,所以这个组合的贡献就是2k-i,这种选法的总贡献就是C(t1,i)* C(t2,k-i )* 2k-i,最终的答案是对于i(i为选择1的数量)属于[0,k],所有选法的总和
还有一些细节的地方,比如k可能会大于t1,这样在循环i的时候,i是可能大于t1的,这时这种选法的贡献应当为0(因为项的数量不够选出i个)
代码实现
本题数据范围较大,因此计算组合数时应该先预处理出阶乘数组,而且计算组合数会用到除法,本题又需要取模,那么还应该预处理出逆元数组
代码如下:
#include<stdio.h>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e7 + 5;
const int mod = 998244353;
int n, k;
int a[N];
long long f[N], inv[N];
long long qpow(long long x, long long y){
long long sum = 1;
while(y){
if(y&1) sum = sum*x%mod;
x = x*x%mod;
y>>=1;
}
return sum;
}
long long cal(int x, int y){
if(x > y) return 0; //选择的数量大于总数量时贡献为0
return f[y]*inv[x]%mod*inv[y-x]%mod;
}
int main(){
scanf("%d%d",&n,&k);
int t1 = 0, t2 = 0;
for(int i = 1; i <= n; i++){
scanf("%d",&a[i]);
if(a[i] == 1) t1++;
if(a[i] == 2) t2++;
}
f[0] = 1, inv[0] = 1;
int maxn = max(t1, t2);
for(int i = 1; i <= maxn; i++){
f[i] = f[i-1]*i%mod;
//inv[i] = inv[i-1]*qpow(i, mod-2)%mod; 这里如果写成这样会超时
}
inv[maxn] = qpow(f[maxn], mod-2);
for(int i = maxn - 1; i >= 1; i--){
inv[i] = inv[i+1]*(i+1)%mod;
}
long long ans = 0;
for(int i = 0; i <= k; i++){
ans = (ans + cal(i, t1)*cal(k-i, t2)%mod*qpow(2, k-i)%mod)%mod;
}
printf("%lld",ans);
return 0;
}