Description
一个包含四个点的完全图,可以在任意节点出发,可以在任意节点结束,给出每个点被经过的次数,求有多少种合法的遍历序列。如果两个序列至少有一位是不同的,则认为它们不相同。
样例:
1 2 1 0
ABCB
BABC
BACB
BCAB
BCBA
CBAB
Input
多组数据。
对于每一组数据:
第一行四个数,分别表示4个点被经过的次数(每个数小于等于1000,经过次数可以为0)
Output
一个表示答案,对998244353取模.
Sample Input
2 3 3 3
Sample Output
12336
Solution
令 n=a+b+c+d
任意相邻两点不同的方案数=把
n
个字母随便放的方案数-至少有一对相邻点相等的方案数+至少有两对相邻点相等的方案数-…+至少有
把
a
个
故 ans=∑x=1n(−1)n−xx!∑i+j+k+l=xCi−1a−1Cj−1b−1Ck−1c−1Cl−1d−1i!j!k!l!
令 A[i]=Ci−1a−1i!,B[j]=Cj−1b−1j!,C[k]=Ck−1c−1k!,D[l]=Cl−1d−1l! ,对这四个序列做三次 NTT 即可
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxbit=14,maxlen=1<<maxbit,maxn=1005,mod=998244353,g=3;
int fact[maxn<<2],inv[maxn];
int wn[maxlen],inv2[maxbit+1];
int mod_pow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}
return ans;
}
void init()
{
wn[0]=1,wn[1]=mod_pow(g,(mod-1)>>maxbit);
for(int i=2;i<maxlen;i++)wn[i]=(ll)wn[i-1]*wn[1]%mod;
inv2[0]=1,inv2[1]=(mod+1)/2;
for(int i=2;i<=maxbit;i++)inv2[i]=(ll)inv2[i-1]*inv2[1]%mod;//预处理2^i的逆元
fact[0]=1;
for(int i=1;i<=4000;i++)fact[i]=(ll)i*fact[i-1]%mod;
inv[1]=1;
for(int i=2;i<=1000;i++)inv[i]=mod-(ll)(mod/i)*inv[mod%i]%mod;
inv[0]=1;
for(int i=1;i<=1000;i++)inv[i]=(ll)inv[i]*inv[i-1]%mod;
}
void ntt(int *x,int len,int sta)
{
for(int i=0,j=0;i<len;i++)
{
if(i>j)swap(x[i],x[j]);
for(int l=len>>1;(j^=l)<l;l>>=1);
}
for(int i=1,d=1;d<len;i++,d<<=1)
for(int j=0;j<len;j+=d<<1)
for(int k=0;k<d;k++)
{
int t=(ll)wn[(maxlen>>i)*k]*x[j+k+d]%mod;
x[j+d+k]=x[j+k]-t<0?x[j+k]-t+mod:x[j+k]-t;
x[j+k]=x[j+k]+t>=mod?x[j+k]+t-mod:x[j+k]+t;
}
if(sta==-1)
{
reverse(x+1,x+len);
int bitlen=0;
while((1<<bitlen)<len)bitlen++;
int val=inv2[bitlen];
for(int i=0;i<len;i++)x[i]=(ll)x[i]*val%mod;
}
}
void NTT(int *a,int *b,int len)
{
ntt(a,len,1),ntt(b,len,1);
for(int i=0;i<len;i++)a[i]=(ll)a[i]*b[i]%mod;
ntt(a,len,-1);
}
void inc(int &x,int y)
{
x=x+y>=mod?x+y-mod:x+y;
}
void dec(int &x,int y)
{
x=x-y<0?x-y+mod:x-y;
}
int C(int n,int m)
{
return (ll)fact[n]*inv[m]%mod*inv[n-m]%mod;
}
int a[4],n,b[4][maxlen];
int main()
{
init();
while(~scanf("%d%d%d%d",&a[0],&a[1],&a[2],&a[3]))
{
n=a[0]+a[1]+a[2]+a[3];
memset(b,0,sizeof(b));
for(int i=0;i<4;i++)
for(int j=1;j<=a[i];j++)
b[i][j]=(ll)C(a[i]-1,j-1)*inv[j]%mod;
int len=1;
while(len<4*n)len<<=1;
for(int i=1;i<4;i++)NTT(b[0],b[i],len);
int ans=0;
for(int i=1;i<=n;i++)
if((n-i)&1)dec(ans,(ll)b[0][i]*fact[i]%mod);
else inc(ans,(ll)b[0][i]*fact[i]%mod);
printf("%d\n",ans);
}
return 0;
}