题目描述:
给出 N , M , K N,M,K N,M,K,求 ∑ i = 0 N ∑ j = 0 M ( i j ) [ i m o d 2 = 0 ] [ j m o d 2 = 0 ] \sum_{i=0}^N\sum_{j=0}^M\binom ij[i\!\!\!\!\mod2=0][j\!\!\!\!\mod2=0] i=0∑Nj=0∑M(ji)[imod2=0][jmod2=0]
N ≤ 1 0 9 , M ≤ 1 0 6 , K ≤ 1 0 9 N\le10^9,M\le10^6,K\le10^9 N≤109,M≤106,K≤109
题目分析:
要是把组合数换成下降幂,斯特林数没法求,枚举复杂度也很高。。
所以直接用二项式展开的系数表示组合数,对上指标隔项求和就可以表示为多项式的等比数列:
除
(
x
+
1
)
2
−
1
(x+1)^2-1
(x+1)2−1,发现上下两项常数项为0,其实就是除
x
−
2
x-2
x−2。
可以发现有
a
i
=
2
b
i
+
b
i
−
1
a_i=2b_i+b_{i-1}
ai=2bi+bi−1,其中
a
i
=
(
N
+
2
i
+
1
)
a_i=\binom {N+2}{i+1}
ai=(i+1N+2)
- 模数为奇数时,存在 2 2 2的逆元,所以 b 0 = a 0 ∗ 2 − 1 b_0=a_0*2^{-1} b0=a0∗2−1, b i = ( a i − b i − 1 ) ∗ 2 − 1 ( i > 0 ) b_i=(a_i-b_{i-1})*2^{-1}(i>0) bi=(ai−bi−1)∗2−1(i>0):
- 模数为 2 t 2^t 2t时,不存在 2 2 2的逆元,但是因为 b i = ∑ j > i ( − 2 ) j − ( i + 1 ) a j b_i=\sum_{j>i}(-2)^{j-(i+1)}a_j bi=∑j>i(−2)j−(i+1)aj,当 j − ( i + 1 ) ≥ t j-(i+1)\ge t j−(i+1)≥t时余数为0,所以可以直接算出 b M = ∑ t + M ≥ j > i a j b_M=\sum_{t+M\ge j>i}a_j bM=∑t+M≥j>iaj,然后用 b i − 1 = a i − 2 b i b_{i-1}=a_i-2b_i bi−1=ai−2bi递推。
所以最后的问题就是怎么算 a i a_i ai,因为要对每一个 0 ≤ i ≤ M 0\le i\le M 0≤i≤M都算出 ( N + 2 i + 1 ) \binom {N+2}{i+1} (i+1N+2),所以首先需要将模数拆分为质数的幂,然后用类似 e x l u c a s exlucas exlucas的方法将下降幂和阶乘中的 p p p的因子提出,剩下的数用 e x g c d exgcd exgcd求出逆元。于是可以 O ( m log m ) O(m\log m) O(mlogm)算出前 M M M个组合数。
最后将所有拆分的模数 C R T CRT CRT合并即可。
Code:
#include<bits/stdc++.h>
#define maxn 1000105
using namespace std;
int n,m,K,T,ans,mod,b,pw[maxn],C[maxn],inv[maxn],tmp;
void exgcd(int a,int b,int &x,int &y){
if(!b) {x=1,y=0;return;}
exgcd(b,a%b,y,x),y-=a/b*x;
}
void calc(const int p,const int k,int M){
int cnt=0,sum=1;
for(int i=pw[0]=1;i<=k;i++) pw[i]=pw[i-1]*p; const int pk = pw[k];
for(int i=1,lim=min(M,pk-1);i<=lim;i++) if(i%p) exgcd(i,pk,inv[i],tmp);
for(int i=1;i<=M;i++){
int x=n+2-i+1,y=i;
if(x) for(;x%p==0;x/=p,cnt++);
for(;y%p==0;y/=p,cnt--);
sum=1ll*sum*x%pk*inv[y%pk]%pk;
C[i] = cnt>=k ? 0 : 1ll*sum*pw[cnt]%pk;
}
}
void merge(int R,int P){
int i1,i2;
exgcd(mod,P,i1,i2); const int md = mod*P;
ans=(1ll*ans*P%md*i2+1ll*R*mod%md*i1)%md, mod=md;
}
int main()
{
scanf("%d%d%d",&n,&m,&K),n-=(n&1),m-=(m&1),m=min(n,m);
ans=0,mod=1;
while(!(K&1)) K>>=1,T++;
if(T){
calc(2,T,min(m+1+T,n+2)); const int md = 1<<T;
for(int i=min(m+1+T,n+2);i>m+1;i--) b=((-2ll)*b+C[i])%md;
int R=b;
for(int i=m-1;i>=0;i--){
b=((-2ll)*b+C[i+2])%md;
if(!(i&1)) R=(R+b)%md;
}
merge(R,md);
}
for(int i=3,x;i*i<=K;i++) if(K%i==0){
for(x=0; K%i==0; K/=i,x++);
calc(i,x,m+1); const int md=pw[x],iv2=(md+1)/2;
int R=b=1ll*C[1]*iv2%md;
for(int j=1;j<=m;j++){
b=1ll*(C[j+1]-b)*iv2%md;
if(!(j&1)) R=(R+b)%md;
}
merge(R,md);
}
if(K>1){
calc(K,1,m+1); const int md=K,iv2=(md+1)/2;
int R=b=1ll*C[1]*iv2%md;
for(int j=1;j<=m;j++){
b=1ll*(C[j+1]-b)*iv2%md;
if(!(j&1)) R=(R+b)%md;
}
merge(R,md);
}
printf("%d\n",(ans+mod)%mod);
}