求组合数有很多方法,需要根据数据范围去选择不同的方法。
C
a
b
=
a
(
a
−
1
)
(
a
−
2
)
.
.
.
.
(
a
−
b
+
1
)
1
∗
2
∗
3
∗
.
.
.
∗
b
=
a
!
b
!
(
a
−
b
)
!
C_a^b=\frac{a(a-1)(a-2)....(a-b+1)}{1*2*3*...*b}=\frac{a!}{b!(a-b)!}
Cab=1∗2∗3∗...∗ba(a−1)(a−2)....(a−b+1)=b!(a−b)!a!
1、取模的组合数计算
给定 n 组询问,每组询问给定两个整数 a,b,请你输出
C
b
a
m
o
d
(
1
0
9
+
7
)
C_b^a\ mod \ (10^9+7)
Cba mod (109+7) 的值。根据a和b以及n的规模,我们将采用不同的方式进行计算。
定义N=
1
0
9
+
7
10^9 + 7
109+7
1. 直接预处理结果
使用递推公式
C
a
b
=
C
a
−
1
b
+
C
a
−
1
b
−
1
C_a^b=C_{a-1}^b+C_{a-1}^{b-1}
Cab=Ca−1b+Ca−1b−1进行递推,预处理出所有c[i][j]数组,代表
C
i
j
C_i^j
Cij的值。
这里我们要求1<= b<=a<= 2000
const int N = 1e9 + 7,M = 2010;
int c[M][M];
//预处理代码
//i代表a
for(int i = 0; i < M; i ++)
{ //j代表b,b要始终小于等于a
for(int j = 0; j <= i; j ++)
{
if(!j) c[i][j] = 1;
else{
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % N;
}
}
}
预处理之后,任何传进来符合范围的数字均可以直接得到结果。
时间复杂度 :
O
(
N
2
)
O(N^2)
O(N2)
2. 预处理中间步骤
同样采用预处理的方式,只是预处理的方式有所不同,上面的我们直接预处理出来,所有的结果,这里我们预处理的是中间步骤。
C
a
b
m
o
d
M
=
a
!
b
!
(
a
−
b
)
!
m
o
d
M
=
a
!
(
b
!
)
−
1
(
(
a
−
b
)
!
)
−
1
m
o
d
M
C_a^b\ mod\ M=\frac{a!}{b!(a-b)!}\ mod\ M=a!(b!)^{-1}((a-b)!)^{-1}\ mod \ M
Cab mod M=b!(a−b)!a! mod M=a!(b!)−1((a−b)!)−1 mod M
这里-1代表乘法逆元。
所以我们只需要处理出所有a、b的阶乘及其逆元,即可快速拼装出结果,这里M为一个质数,所以可以使用快速幂进行实现。如果M是一个合数,就可以使用扩展欧几里得
可以参考
时间复杂度:
O
(
N
l
o
g
N
)
O(NlogN)
O(NlogN)
这里我们要求1 <= b <= a <= 10^5
#include<iostream>
using namespace std;
const int N = 1e5 + 10, M = 1e9 + 7;
int fact[N],infact[N];
//快速幂
int qmi(int a, int k, int p)
{
int res = 1;
while(k)
{
if(k & 1)
{
res = (long long)res * a % p;
}
a = (long long)a * a % p;
k >>= 1;
}
return res % p;
}
//初始化阶乘以及阶乘的逆元
void init()
{
fact[0] = infact[0] = 1;
for(int i = 1; i < N; i ++)
{
//求阶乘
fact[i] = (long long)fact[i - 1] * i % M;
//求阶乘逆元
infact[i] = (long long)infact[i - 1] * qmi(i,M - 2,M) % M;
}
}
int main()
{
int n;
cin >> n;
init();
for(int i = 0; i < n; i ++)
{
int a,b;
cin >> a >> b;
//组合成结果
cout << (long long)fact[a] * infact[a - b] % M *infact[b] % M << endl;
}
}
3. 卢卡斯定理
适用于a和b的取值范围超大(例如 1 <= b <= a <= 10^18),与此同时取模的数p(1<=p<=10^5)不算大的情况。
C
a
b
m
o
d
p
≡
C
a
m
o
d
p
b
m
o
d
p
∗
C
a
/
p
b
/
p
C_a^b\ mod\ p\equiv C_{a\ mod\ p}^{b\ mod\ p}*C_{a/p}^{b/p}
Cab mod p≡Ca mod pb mod p∗Ca/pb/p
int qmi(int a, int k, int p)
{
int res = 1;
while (k)
{
if (k & 1) res = (LL)res * a % p;
a = (LL)a * a % p;
k >>= 1;
}
return res;
}
//直接计算的方法
int C(int a, int b, int p)
{
if (b > a) return 0;
int res = 1;
for (int i = 1, j = a; i <= b; i ++, j -- )
{
res = (LL)res * j % p;
res = (LL)res * qmi(i, p - 2, p) % p;
}
return res;
}
int lucas(LL a, LL b, int p)
{
if (a < p && b < p) return C(a, b, p);
//a % p b % p 取值较小,可以直接进行计算。
//a / p b / p 取值不一定小,继续使用lucas计算
return (LL)C(a % p, b % p, p) * lucas(a / p, b / p, p) % p;
}
2、直接计算出结果(不取模)
对于直接计算的情况来说,会采取分解质因数的思路进行计算,具体如下:
C
a
b
=
a
(
a
−
1
)
(
a
−
2
)
.
.
.
.
(
a
−
b
+
1
)
1
∗
2
∗
3
∗
.
.
.
∗
b
=
a
!
b
!
(
a
−
b
)
!
C_a^b=\frac{a(a-1)(a-2)....(a-b+1)}{1*2*3*...*b}=\frac{a!}{b!(a-b)!}
Cab=1∗2∗3∗...∗ba(a−1)(a−2)....(a−b+1)=b!(a−b)!a!
- 筛出所有小于a的质数
- 计算出a!、b!、(a-b)!中,包含各个质数的个数
- 用a!的质数的个数减去b!质数的个数减去(a-b)!质数的个数,就是整个答案中,该质因数的个数
- 将所有的质因数按照数量相乘,就可以得到最终的结果,这一步需要高精度乘法实现
其中步骤计算出a!、b!、(a-b)!中,包含各个质数的个数可以使用下面的方式求出
a
!
=
a
p
+
a
p
2
+
a
p
3
.
.
.
+
a
p
n
+
.
.
.
.
a! = \frac{a}{p}+ \frac{a}{p^2}+ \frac{a}{p^3} ...+ \frac{a}{p^n}+....
a!=pa+p2a+p3a...+pna+....
const int N = 5010;
typedef long long LL;
int prime[N],cnt;
bool st[N];
//筛素数
void getP(int n)
{
for(int i = 2; i <= n; i ++)
{
if(!st[i])prime[cnt ++] = i;
for(int j = 0; prime[j] <= n / i; j ++)
{
st[prime[j] * i] = true;
if(i % prime[j] == 0)break;
}
}
}
//高精度乘法
vector<int> mul(vector<int> a ,int b)
{
int t = 0;
vector<int> res;
for(int i = 0; i < a.size(); i ++)
{
t = a[i] * b + t;
res.push_back(t % 10);
t /= 10;
}
while(t)
{
res.push_back(t % 10);
t /= 10;
}
return res;
}
//获取a!的阶乘中包含的质数p的个数
int getN(int a, int p)
{
int res = 0;
while(a)
{
res += a / p;
a /= p;
}
return res;
}
int main()
{
int a,b;
scanf("%d%d",&a,&b);
getP(a);
vector<int> ans;
ans.push_back(1);
for(int i = 0; i < cnt; i ++)
{
int p = prime[i];
int k = getN(a,p) - getN(b,p) - getN(a - b,p);
for(int j = 0; j < k; j ++)
{
ans = mul(ans,p);
}
}
for(int i = ans.size() - 1; i >= 0; i --)
{
printf("%d",ans[i]);
}
}