原题链接:https://codeforces.com/problemset/problem/1931/G
题目描述:
有 4 种拼图,其中第 i 种拼图有 ci 张。
两张拼图可以连结当且仅当它们相邻的卡槽中一个凹陷一个突出。
我们希望将所有的拼图从左往右拼起来,求总方案数。答案对 998244353 取模。
输入输出描述:
多测。1≤t≤2×10^5,0≤ci≤10^6,∑(c1+c2+c3+c4)≤4×10^6。
输入输出样例
输入
11
1 1 1 1
1 2 5 10
4 6 100 200
900000 900000 900000 900000
0 0 0 0
0 0 566 239
1 0 0 0
100 0 100 0
0 0 0 4
5 5 0 2
5 4 0 5
输出
4
66
0
794100779
1
0
1
0
1
36
126
解题思路:
四种图形的个数分别是c1,c2,c3,c4,首先只考虑第一种图形和第二种图形,如果abs(c1-c2)>1,那么必然多出一个第一种图形或者第二种图形无法插入链中,这个可以自己画图分析一下,这里不做具体描述,然后就是abs(c1-c2)<=1的情况。
当c1==c2时,如下图所示:
对于图形3会有c1种插入位置,图形4会有c2+1中插入位置,此时这个问题就可以转换为了将b个球放进a个盒子有多少种放法了,允许某些盒子为空,我们不妨设为f(a,b),实际上我们可以把这个过程看为将a+b个球放进a个盒子中有多少种放法,此时不允许出现空盒子,这里就相当于在a+b-1个空选择a-1个板子插入,就是C(a+b-1,a-1),对应到这里就是C(c1+c3-1,c1-1)*C(c2+c4,c2),然后如果按照[2,1,2,1,2,1]的方式排列,那么图形3有c1+1插入位置,图形4有c2种插入位置,那么就是C(c1+c3,c1)*C(c2+c4-1,c2-1),这俩种情况加起来即可。
当abs(c1-c2)==1时,如下图所示:
对于c1==c2-1或者c1==c2+1俩种情况,对于图形3和图形4都有max(c1,c2)种插入位置,同上分析,首先不妨设c=max(c1,c2),那么这里就是C(c+c3-1,c-1)*C(c+c4-1,c-1)。
还要注意特判一些边界情况,当c1==c2时,有可能c1==c2==0,那么如果此时c3!=0 && c4!=0,这种情况无法拼出一条链,输出0,也就是在图形1和图形2的个数都是0时,不允许图形3和图形4同时出现。
到这里就分析的差不多了,由于多测数据,我们需要提前预处理组合数,预处理之后,对于每组测试数据直接根据推导出的公式直接输出答案即可。
时间复杂度:O(T+nlog(n)),T表示测试数据组数,n=2e6。
空间复杂度:O(n),n=2e6。
cpp代码如下:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 2e6 + 10, mod = 998244353; //注意这里不是模(1e9+7),我开始习惯性写成(1e9+7)导致wa好几发
int T;
int fact[N],infact[N];
int qmi(int x,int k,int p)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%p;
x=1ll*x*x%p;
k>>=1;
}
return res;
}
void init() //预处理组合数
{
fact[0]=infact[0]=1;
for(int i=1;i<N;i++)
{
fact[i]=1ll*fact[i-1]*i%mod;
infact[i]=1ll*infact[i-1]*qmi(i,mod-2,mod)%mod;
}
}
int C(int n,int m)
{
if(n<0 || m<0)return 0;
return 1ll*fact[n]*infact[m]%mod*infact[n-m]%mod;
}
void solve()
{
int c1,c2,c3,c4;
cin>>c1>>c2>>c3>>c4;
if(abs(c1-c2)>1){
cout<<0<<'\n';
return ;
}
if(c1==c2){
if(!c1){ //特判边界情况
if(c3 && c4)cout<<0<<'\n';
else cout<<1<<'\n';
}else {
cout<<(1ll*C(c1+c3-1,c1-1)*C(c2+c4,c2)%mod+1ll*C(c1+c3,c1)*C(c2+c4-1,c2-1)%mod)%mod<<'\n';
}
}else {
int c=max(c1,c2);
cout<<1ll*C(c+c3-1,c-1)*C(c+c4-1,c-1)%mod<<'\n';
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
init();
cin>>T;
while(T--)
{
solve();
}
return 0;
}