题目描述
sol :
考点:容斥原理。
考虑枚举最终状态中恰好有 x 行全为黑, y 列全为黑。
即求 ∑ x = A n ∑ y = B m ( n x ) ( m y ) f ( n − x , m − y ) \sum_{x=A}^n\sum_{y=B}^m\binom{n}{x}\binom{m}{y}f(n-x,m-y) ∑x=An∑y=Bm(xn)(ym)f(n−x,m−y)
定义 f(a,b) 为每行每列都至少有一个白的方案数。
这里规定 f(0,0)=1 ,但是当其中一维为 0 时答案为 0 。
考虑 f(a,b) 怎么算。
我们可以考虑容斥:假设某一状态恰好有 x 行 和 y 列 为黑,那么对答案贡献为 0。
根据组合恒等式 ∑ i = 0 x ∑ j = 0 y ( − 1 ) i + j ( x i ) ( y j ) = 0 \sum_{i=0}^x\sum_{j=0}^y(-1)^{i+j}\binom{x}{i}\binom{y}{j}=0 ∑i=0x∑j=0y(−1)i+j(ix)(jy)=0 ,其中 x,y > 0 。
不难得到: f ( a , b ) = ∑ x = 0 a ∑ y = 0 b ( a x ) ( b y ) ( − 1 ) x + y 2 ( a − x ) ( b − y ) f(a,b)=\sum_{x=0}^a\sum_{y=0}^b\binom{a}{x}\binom{b}{y}(-1)^{x+y}2^{(a-x)(b-y)} f(a,b)=∑x=0a∑y=0b(xa)(yb)(−1)x+y2(a−x)(b−y)
这里要看到 x=a 或 y=b 的情况。仔细思考后发现就是一维现象的特例,所以贡献也是 0 。
(这里的容斥值得仔细品味)
那么最后就只剩下横纵都有白的情况,也就是 f(a,b) 。
然后把 后面那个 2 的次幂提出来,前面的组合数预处理,可以通过本题。
#include<bits/stdc++.h>
#define int long long
#define ll long long
using namespace std;
const int mod=998244353;
const int Maxn=3005;
ll fpow(ll x,ll y) {
ll mul(1);
for(;y;y>>=1) {
if(y&1) mul=mul*x%mod;
x=x*x%mod;
}
return mul;
}
//think.
int n,m,A,B;
ll fac[Maxn],inv[Maxn],ksm[Maxn*Maxn],ans1[Maxn],ans2[Maxn];
ll F(int x) {
if(x&1) {
return mod-1;
}
else {
return 1;
}
}
void init(int N) {
fac[0]=1;
for(int i=1;i<=N;i++) {
fac[i]=fac[i-1]*i%mod;
}
inv[N]=fpow(fac[N],mod-2);
for(int i=N;i>=1;i--) {
inv[i-1]=inv[i]*i%mod;
}
}
ll C(int x,int y) {
return fac[x]*inv[y]%mod*inv[x-y]%mod;
}
//eat shit
//shit you !
ll chk(int x,int y) {
if(!x||!y) return ksm[x+y];
return ksm[x*y];
}
ll solve(int x,int y) {
ll res=0;
for(int i=0;i<=x;i++) {
for(int j=0;j<=y;j++) {
res=(res+F(i+j)*C(x,i)%mod*C(y,j)%mod*ksm[(x-i)*(y-j)]%mod)%mod;
}
}
return res;
}
signed main() {
// freopen("data.in", "r", stdin);
init(3000);
while(scanf("%lld%lld%lld%lld",&n,&m,&A,&B)!=EOF) {
memset(ans1,0,sizeof ans1);
memset(ans2,0,sizeof ans2);
for(int i=0;i<=n;i++) {
for(int j=i;j<=n-A;j++) {
if(j&1) ans1[i]=(ans1[i]-C(n,j)%mod*C(j,i)%mod+mod)%mod;
else ans1[i]=(ans1[i]+C(n,j)%mod*C(j,i)%mod)%mod;
}
}
for(int i=0;i<=m;i++) {
for(int j=i;j<=m-B;j++) {
if (j&1) ans2[i]=(ans2[i]-C(m,j)%mod*C(j,i)%mod+mod)%mod;
else ans2[i]=(ans2[i]+C(m,j)%mod*C(j,i)%mod)%mod;
}
}
ll res=0;
for(int i=0;i<=n;i++) {
int dj=fpow(2,i);
for(int j=0,d=1;j<=m;j++,d=d*dj%mod) {
if ((i+j)&1) res=(res-d*ans1[i]%mod*ans2[j]%mod+mod)%mod;
else res=(res+d*ans1[i]%mod*ans2[j]%mod)%mod;
}
}
// for(int i=A;i<=n;i++) {
// for(int j=B;j<=m;j++) {
// res=(res+C(n,i)*C(m,j)*solve(n-i,m-j)%mod)%mod;
// }
// }
printf("%lld\n",res);
}
}