卢卡斯定理
卢卡斯定理可以在模数较小的时候加速计算组合数,但要要求模数是质数:
C
(
n
,
m
)
=
C
(
n
%
p
,
m
%
p
)
×
C
(
n
/
p
,
m
/
p
)
C(n,m)=C(n\%p,m\%p)\times C(n/p,m/p)
C(n,m)=C(n%p,m%p)×C(n/p,m/p)
扩展卢卡斯
上面的定理只能解决模数是质数的情况,那么模数不是质数怎么办呢?这就要用到我们的扩展卢卡斯。
扩展卢卡斯的核心思想是把合数模数拆成若干个 p i k i p_i^{k_i} piki去算模出来的结果,然后用 C R T \tt CRT CRT合并即可。
我们的重点放在求模数是 p i k i p_i^{k_i} piki时的答案,由于阶乘里面的质因子 p i p_i pi会很大程度干扰我们的计算(如我们的逆元就不能在不互质的情况下求),那么我们先暴力把这些质因子提出来,我们用勒让德定理来算这些质因子,代码如下(这个定理又好证明又好理解):
for(int i=n;i;i/=pi) ind+=i/pi;
提出来以后我们的问题是算除去这些质因子的阶乘,这里我举一个经典的例子(模
9
9
9):
22
!
=
(
1
+
2
+
4
+
5
+
7
+
8
)
×
(
10
+
11
+
13
+
14
+
16
+
17
)
×
(
19
+
20
+
22
)
×
3
6
×
(
1
+
2
+
3
+
4
+
5
+
6
+
7
)
22!=(1+2+4+5+7+8)\times(10+11+13+14+16+17)\times(19+20+22)\\\times 3^6\times (1+2+3+4+5+6+7)
22!=(1+2+4+5+7+8)×(10+11+13+14+16+17)×(19+20+22)×36×(1+2+3+4+5+6+7)你会发现前两项在模意义下是一样的,所以我们暴力求出第一项然后快速幂即可。对于第三项是冗余部分可以暴力计算,最后一项可以递归计算(
f
a
c
(
22
/
3
)
fac(22/3)
fac(22/3) )
在扩展卢卡斯中逆元最好用 e x g c d \tt exgcd exgcd计算。
例题
#include <cstdio>
const int M = 100005;
#define int long long
int read()
{
int num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
return num*flag;
}
int n,m,p;
void exgcd(int a,int b,int &x,int &y)
{
if(!b){x=1;y=0;return ;}
exgcd(b,a%b,y,x);y-=x*(a/b);
}
int inv(int v,int p)
{
int x,y;exgcd(v,p,x,y);
return (x%p+p)%p;
}
int qkpow(int a,int b,int p)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%p;
a=a*a%p;
b>>=1;
}
return r;
}
int fac(int n,int pi,int pk)
{
if(!n) return 1;int ans=1;
for(int i=2;i<pk;i++) if(i%pi) ans=ans*i%pk;
ans=qkpow(ans,n/pk,pk);
for(int i=2;i<=n%pk;i++) if(i%pi) ans=ans*i%pk;
return ans*fac(n/pi,pi,pk)%pk;//
}
int L(int n,int m,int pi,int pk)
{
int ind=0;
for(int i=n;i;i/=pi) ind+=i/pi;
for(int i=m;i;i/=pi) ind-=i/pi;
for(int i=n-m;i;i/=pi) ind-=i/pi;
int x=fac(n,pi,pk),y=fac(m,pi,pk),z=fac(n-m,pi,pk);
return x*inv(y,pk)%pk*inv(z,pk)%pk*qkpow(pi,ind,pk)%pk;
}
int C(int n,int m,int p)
{
int tmp=p,ans=0;
for(int i=2;i*i<=tmp;i++)
if(tmp%i==0)
{
int pk=1;
while(tmp%i==0) tmp/=i,pk*=i;
ans=(ans+L(n,m,i,pk)*inv(p/pk,pk)%p*p/pk%p)%p;
}
if(tmp>1) ans=(ans+L(n,m,tmp,tmp)*inv(p/tmp,tmp)%p*p/tmp%p)%p;
return ans;
}
signed main()
{
n=read();m=read();p=read();
printf("%lld\n",C(n,m,p));
}