原题链接:
C. The Intriguing Obsession
大意:
有若干个小岛,由三种颜色组成,现在在小岛之间添加桥,桥的长度为1,要求相同颜色的小岛之间的路长度不小于三,求有多少的方法。
思路:
可以看成一个三棱柱,每条棱代表一个颜色,易得棱上没有线段相连,那么考虑一个面上的线段连法。易得 不能有两个 A 棱上的点同时连到 B棱上的相同点。 线段数 i 从 0~min(a,b) ,对于 i,从 A 棱上选取点有C(a,i) 种,B 上取点有 C(b,i) 种 ,选取点后排列有 i! 种,所以一个面上种数为:
∑i=0min(a,b)C(a,i)×C(b,i)×i!
对于三个面乘法原理即可。
具体实现:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair <int,int> pii;
#define mem(s,t) memset(s,t,sizeof(s))
#define D(v) cout<<#v<<" "<<v<<endl
#define inf 0x3f3f3f3f
#define pb push_back
//#define LOCAL
const int mod=998244353;
const int MAXN =5e3+10;
//O(n)的算法
ll F[MAXN], Finv[MAXN], inv[MAXN];//F是阶乘,Finv是逆元的阶乘
void init(){
inv[1] = 1;
for(int i = 2; i < MAXN; i ++){
inv[i] = (mod - mod / i) * 1LL * inv[mod % i] % mod;
}
F[0] = Finv[0] = 1;
for(int i = 1; i < MAXN; i ++){
F[i] = F[i-1] * 1LL * i % mod;
Finv[i] = Finv[i-1] * 1LL * inv[i] % mod;
}
}
ll C(ll n, ll m){
if(m < 0 || m > n) return 0;
return F[n] * 1LL * Finv[n - m] % mod * Finv[m] % mod;
}
//计算
ll calc(ll a,ll b){
ll ret=0;
if(a<b) swap(a,b);
for(ll i=0;i<=min(a,b);i++){
ret+=C(a,i)*C(b,i)%mod*F[i];
ret%=mod;
}
return ret;
}
int main() {
#ifdef LOCAL
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
ll a,b,c;
cin>>a>>b>>c;
init();
ll ans=calc(a,b)*calc(b,c)%mod*calc(a,c)%mod;
cout<<ans<<endl;
return 0;
}
自己写的阶乘逆元 O(n)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair <int,int> pii;
#define mem(s,t) memset(s,t,sizeof(s))
#define D(v) cout<<#v<<" "<<v<<endl
#define inf 0x3f3f3f3f
#define pb push_back
//#define LOCAL
const int mod=998244353;
const int MAXN =5e3+10;
ll inv[MAXN],fac[MAXN];
ll quick_mod(ll a,ll b){
ll ret=1;
a%=mod;
while(b){
if(b&1) ret=ret*a%mod;
b>>=1;
a=a*a%mod;
}
return ret;
}
void init(){
mem(inv,0);
mem(fac,0);
fac[0]=fac[1]=1;
for(ll i=2;i<MAXN;i++) fac[i]=i*fac[i-1]%mod;
inv[MAXN-1]=quick_mod(fac[MAXN-1],mod-2);
for(ll i=MAXN-2;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
}
ll C(ll a,ll b){
if(b>a) return 0;
return fac[a]*inv[a-b]%mod*inv[b]%mod;
}
//计算
ll calc(ll a,ll b){
ll ret=0;
if(a<b) swap(a,b);
for(ll i=0;i<=min(a,b);i++){
ret+=C(a,i)*C(b,i)%mod*fac[i];
ret%=mod;
}
return ret;
}
int main() {
#ifdef LOCAL
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
ll a,b,c;
cin>>a>>b>>c;
// D(11);
init();
ll ans=calc(a,b)*calc(b,c)%mod*calc(a,c)%mod;
cout<<ans<<endl;
return 0;
}