[python刷题模板] 取模技巧/组合数取模模板/快速幂取模模板
- 一、 算法&数据结构
- 二、 模板代码
- 0. (常用)组合数取模,预处理O(n+lgn),询问O(1):[C(m,r) = m!%MOD * inv(r!)%MOD * inv((m-r)!)%MOD]:
- 1. 普通取模(略),但补充python3.7以前手写排列组合数公式:
- 2. 快速幂取模
- 3. 组合数取模,利用python的comb
- 4. 组合数取模m,r<2000,杨辉三角O(n^2),(也叫:帕斯卡恒等式):C(m,r) = C(m-1,r-1)+C(m-1,r)
- 5. 组合数取模m,r<1e5,MOD≈1e9,线性递推O(n),C(m,r)=C(m,r-1)*(m-r+1)/r
- 6. 组合数取模m,r<1e18,MOD<1e5,Lucas,卢卡斯定理
- 7. 组合数取模m,r<1e18,MOD<1e5,MOD可能是合数,扩展Lucas,扩展卢卡斯
- 11. 模板代码:快速幂求单个数的逆元(费马小定理)复杂度O(lgn)
- 12. 模板代码:快速幂求单个数的逆元(扩展欧几里得)复杂度O(lgn)
- 13. (最常用)模板代码:线性递推求n个数逆元inv[i] = (p - p / i) * (inv[p % i]) % p ,复杂度O(n)
- 14. (不如用线性递推)模板代码:欧拉定理,看不明白,以后再补充。
- 三、其他
- 四、更多例题
一、 算法&数据结构
1. 描述
有的题目的返回结果经常很大,为了不超过int32的范围,通常题目会要求答案对一个大质数MOD进行取模。
这个操作会造成的结果是:
一个小的数字对MOD取模依然是它本身;
而大的数字在特定运算下,满足同余定理,满足结合律和交换率,这样可以避免不同计算方法带来的结果不同。
加减乘,都满足同余性质;注意除法不满足,这带来了组合数计算的问题。
通常题目给出的MOD有:1e9+7、998244353,他们的特点是足够大且都是质数。
- 写在最前边:
- 遇到取模题,直接把这句话写在最后再继续看题:
return ans%MOD
。 - 一定不要忘记取模导致WA!
- 这里注意,我们用的是python:
- 首先python的int永远不会越界,原理是如果超过2^32会自动转大数,因此支持大数计算。
- 但千万不要因为这个就最后统一取模,大数的计算式非常慢的,会越界。因此中间步骤也要取模。
- python3.8以后添加了
math.comb
、math.perm
函数,如果求组合数可以直接计算,最后再取模(leetcode支持、acwing不支持)。 - perm如果不传入第二个参数,则
perm(n)=factorial(n)
,默认全排列=阶乘。 - 同样阶乘可以用
math.factorial
(python3.x)。 - 因此可以用阶乘模拟perm和comb。
- 幂运算可以
n**k
。
2. 复杂度分析
- 通常我们希望取模代价O(1)的。单词取模当然是O(1)的。
- 快速幂取模
- 组合数取模:
- 杨辉三角:O(n^2)。
3. 常见应用
- 普通取模。
- 快速幂取模。
- 组合数取模。
- 杨辉三角(m,r<2000):C(m,r) = C(m-1,r-1)+C(m-1,r)。
- C(m,r)为从m个数字中选r个数字的方案数,我们考虑取第m个数时的情况:任选一个数(假设是最左边的数),我们选择它,那么剩下m-1个数要选出r-1个数,方案数是C(m-1,r-1);如果不选它,要从剩下m-1个数选r个数,方案数C(m-1,r);因此总方案数C(m,r)=C(m-1,r-1)+C(m-1,r);特别的C(0,0)=1、C(i,0)=1;
- 线性递推(m,r<1e5,MOD≈1e9):C(m,r)=C(m,r-1)*(m-r+1)/r。
- 线性复杂度,但是公式存在除法,而除法不满足同余性质,因此要求逆元,把除法换为乘法。
- 如果要求单个数字的逆元我们可以用扩展欧几里得或者快速幂(建议exgcd,更快点);
- 如果要处理的是一堆,就是1-n所有数字的逆元,我们可以用欧拉筛打出逆元表,或者直接利用递推inv[i] = (p - p / i) * (inv[p % i]) % p
- 杨辉三角(m,r<2000):C(m,r) = C(m-1,r-1)+C(m-1,r)。
4. 常用优化
- 普通取模,如果认为数值的增长一次不会超过MOD,可以用减法来代替取模。
- 快速幂取模,打好板子的话,可以用循环代替递归。
- 组合数取模,杨辉三角、线性递推、卢卡斯定理。
二、 模板代码
0. (常用)组合数取模,预处理O(n+lgn),询问O(1):[C(m,r) = m!%MOD * inv(r!)%MOD * inv((m-r)!)%MOD]:
-
首先要知道公式C(m,r) = m!//(r!*(m-r)!)
-
我们知道除法不符合同余,因此要用逆元来代替除法,即C(m,r) = m!%MOD * inv(r!)%MOD * inv((m-r)!)%MOD
-
于是需要先预处理每个1~m的阶乘fact。
-
然后预处理阶乘的逆元,这里有两种方式:
- nlgn方法,每个数据都套用费马小定理或扩展欧几里得,那么每次lgn,总共O(nlgn)。
- 线性方法,把最后一个阶乘计算一次费马小定理。然后倒序递推计算前一个阶乘,我们知道fact[i-1] = fact[i]//i,把这个公式两边取逆元(倒数),得inv_f[i-1] = inv_f[i]*i,最后记得要取模。复杂度O(n+lgn)。
-
总结 ,预处理nlgn,询问O(1)
-
题外话:python的builtin.pow支持第三个参数mod,但是实测不如自己手写的快速幂快,会TLE。
class ModComb:
def __init__(self, n, p):
"""
初始化,为了防止模不一样,因此不写默认值,强制要求调用者明示
:param n:最大值
:param p: 模
"""
self.p = p
self.inv_f, self.fact = [1] * (n + 1), [1] * (n + 1)
inv_f, fact = self.inv_f, self.fact
for i in range(2, n + 1):
fact[i] = i * fact[i - 1] % p
inv_f[-1] = pow(fact[-1], p - 2, p)
for i in range(n, 0, -1):
inv_f[i - 1] = i * inv_f[i] % p
def comb(self, m, r):
if m < r or r < 0:
return 0
return self.fact[m] * self.inv_f[r] % self.p * self.inv_f[m - r] % self.p
def perm_count_with_duplicate(self, a):
"""含重复元素的列表a,全排列的种类。
假设长度n,含x种元素,分别计数为[c1,c2,c3..cx]
则答案是C(n,c1)*C(n-c1,c2)*C(n-c1-c2,c3)*...*C(cx,cx)
或:n!/c1!/c2!/c3!/../cn!
"""
ans = self.fact[len(a)]
for c in Counter(a).values():
ans = ans * self.inv_f[c] % self.p
return ans
# 下边这种也可以
# s = len(a)
# ans = 1
# for c in Counter(a).values():
# ans = ans * self.comb(s,c) % MOD
# s -= c
# return ans
CF例题。这里我试了如果不用线性求逆元,每个元素都用费马小定理,预处理nlogn 会TLE。
import sys
from collections import *
from itertools import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *
if sys.hexversion == 50924784:
sys.stdin = open('cfinput.txt')
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
# MOD = 10 ** 9 + 7
MOD = 998244353
"""https://codeforces.com/problemset/problem/1420/D
输入 n, k (1≤k≤n≤3e5) 和 n 个闭区间,区间的范围在 [1,1e9]。
你需要从 n 个区间中选择 k 个区间,且这 k 个区间的交集不为空。
输出方案数模 998244353 的结果。
输入
7 3
1 7
3 8
4 5
6 7
1 3
5 10
8 9
输出 9
输入
3 1
1 1
2 2
3 3
输出 3
输入
3 2
1 1
2 2
3 3
输出 0
输入
3 3
1 3
2 3
3 3
输出 1
输入
5 2
1 3
2 4
3 5
4 6
5 7
输出 7
"""
# 1778 ms
def solve1(n, k, lr):
def quick_pow_mod(a, b, p):
ans = 1
while b:
if b & 1:
ans = (ans * a) % p
a = (a * a) % p
b >>= 1
return ans
MAX_N = n + 1
inv_f, fact = [1] * MAX_N, [1] * MAX_N
p = MOD
for i in range(2, MAX_N):
fact[i] = i * fact[i - 1] % p
inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
for i in range(MAX_N - 1, 0, -1):
inv_f[i - 1] = inv_f[i] * i % p
def get_c(m, r):
if m < r or r < 0:
return 0
# 公式C(m,r) = m!//(r!*(m-r)!)
return fact[m] * inv_f[r] % p * inv_f[m - r] % p
lr.sort()
h = []
ans = 0
for l, r in lr:
while h and h[0] < l:
heapq.heappop(h)
if len(h) >= k - 1:
ans = (ans + get_c(len(h), k - 1)) % MOD
heapq.heappush(h, r)
print(ans)
# 358 ms
def solve2(n, k, lr):
def quick_pow_mod(a, b, p):
ans = 1
while b:
if b & 1:
ans = (ans * a) % p
a = (a * a) % p
b >>= 1
return ans
MAX_N = n + 1
inv_f, fact = [1] * MAX_N, [1] * MAX_N
p = MOD
for i in range(2, MAX_N):
fact[i] = i * fact[i - 1] % p
inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
for i in range(MAX_N - 1, 0, -1):
inv_f[i - 1] = inv_f[i] * i % p
def get_c(m, r):
if m < r or r < 0:
return 0
# 公式C(m,r) = m!//(r!*(m-r)!)
return fact[m] * inv_f[r] % p * inv_f[m - r] % p
l, r = [], []
for a, b in lr:
l.append(a)
r.append(b)
l.sort()
r.sort()
i = j = s = 0
ans = 0
while i < n and j < n:
if l[i] <= r[j]:
ans = (ans + get_c(s, k - 1)) % MOD
i += 1
s += 1
else:
j += 1
s -= 1
print(ans)
# 514 ms
def solve(n, k, lr):
def quick_pow_mod(a, b, p):
ans = 1
while b:
if b & 1:
ans = (ans * a) % p
a = (a * a) % p
b >>= 1
return ans
MAX_N = n + 1
inv_f, fact = [1] * MAX_N, [1] * MAX_N
p = MOD
for i in range(2, MAX_N):
fact[i] = i * fact[i - 1] % p
inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
for i in range(MAX_N - 1, 0, -1):
inv_f[i - 1] = inv_f[i] * i % p
def get_c(m, r):
if m < r or r < 0:
return 0
# 公式C(m,r) = m!//(r!*(m-r)!)
return fact[m] * inv_f[r] % p * inv_f[m - r] % p
x = []
for a, b in lr:
x.append(a << 1)
x.append(b << 1 | 1)
x.sort()
s = ans = 0
for i in x:
if i & 1:
s -= 1
else:
ans = (ans + get_c(s, k - 1)) % MOD
s += 1
print(ans)
if __name__ == '__main__':
n, k = RI()
lr = []
for _ in range(n):
lr.append(RILST())
solve(n, k, lr)
1. 普通取模(略),但补充python3.7以前手写排列组合数公式:
import io
import os
import sys
from collections import deque
from math import factorial
if os.getenv('LOCALTESTACWING'):
sys.stdin = open('input.txt')
else:
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
if __name__ == '__main__':
n, m, k = map(int, input().split())
#def factorial(k):
# ans = 1
# for i in range(2,k+1):
# ans*=i
# return ans
def comb(m,r):
return factorial(m)//(factorial(r)*factorial(m-r))
def perm(m,r):
return factorial(m)//factorial(m-r)
ans = m*comb(n-1,k)*((m-1)**k)
print(ans % MOD)
2. 快速幂取模
链接: 1498. 满足条件的子序列数目
这题是排序后,寻找l\r长度,然后每个数字都可以出现或不出现,对答案贡献是2^n。
因此要对幂取模,如果是别的语言就要计算,python可以直接2**n%MOD。
题解在:[英雄星球七月集训LeetCode解题日报] 第9日 二分查找
MOD = 10**9+7
@cache
def fastPower(a,n):
ans = 1
base = a
while n != 0:
if (n & 1 != 0):
ans *= base % MOD
base *= base % MOD
n >>= 1
return ans % MOD
f = [1]
for i in range(10**5):
f.append(f[-1]*2%MOD)
class Solution:
def numSubseq(self, nums: List[int], target: int) -> int:
nums.sort()
n = len(nums)
l,r=0,n-1
ans = 0
while l<=r :
while r>=l and nums[l] + nums[r]>target:
r -= 1
if l>r:
break
# ans = (ans+2**(r-l))%MOD
ans = (ans+f[r-l])%MOD
# print(l,r,ans)
l += 1
return ans%MOD
3. 组合数取模,利用python的comb
链接: 2338. 统计理想数组的数目
一道周赛T4,难点不在这,但是用到了组合数取模。
[LeetCode周赛复盘-补] 第 300 场周赛20220710
MOD = 10**9+7
primes = [1] * 10001
primes[0] =0
primes[1] =1
for i in range(2,10000):
if primes[i]:
for j in range(2*i,10001,i):
primes[j] = 0
ABC = {2,3,5,7}
@cache
def get_prime_reasons(x):
if x == 1:
return Counter()
if primes[x]:
return Counter([x])
for i in range(2,int(x**0.5)+1):
if x % i == 0:
return get_prime_reasons(i) + get_prime_reasons(x//i)
class Solution:
def idealArrays(self, n: int, maxValue: int) -> int:
ans = 0
# print(get_prime_reasons(20))
for i in range(1,maxValue+1):
mul = 1
for k in get_prime_reasons(i).values():
mul = mul* comb(n+k-1,k)%MOD
ans = (ans+mul)%MOD
return ans%MOD
4. 组合数取模m,r<2000,杨辉三角O(n^2),(也叫:帕斯卡恒等式):C(m,r) = C(m-1,r-1)+C(m-1,r)
链接: AcWing 4496. 吃水果
第一次打acwing周赛的T3.用到了组合数。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716
import io
import os
import sys
from collections import deque
from math import factorial
if os.getenv('LOCALTESTACWING'):
sys.stdin = open('input.txt')
else:
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2001
C = [[0] * MAX_N for _ in range(MAX_N)]
C[0][0] = 1
for i in range(1, MAX_N):
C[i][0] = 1
for j in range(1, MAX_N):
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD
if __name__ == '__main__':
n, m, k = map(int, input().split())
ans = m * C[n - 1][k] * ((m - 1) ** k)
print(ans % MOD)
5. 组合数取模m,r<1e5,MOD≈1e9,线性递推O(n),C(m,r)=C(m,r-1)*(m-r+1)/r
链接: AcWing 4496. 吃水果
例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716
import io
import os
import sys
from collections import deque
from math import factorial
if os.getenv('LOCALTESTACWING'):
sys.stdin = open('input.txt')
else:
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2005
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
inv[i] = (p - p // i) * inv[p % i] % p
def get_c(m, r):
if m<r:
return 0
c = 1
for i in range(1, r + 1):
c = c * (m - i + 1) % MOD * inv[i] % MOD
return c
if __name__ == '__main__':
n, m, k = map(int, input().split())
ans = m * get_c(n - 1, k) * ((m - 1) ** k)
print(ans % MOD)
6. 组合数取模m,r<1e18,MOD<1e5,Lucas,卢卡斯定理
链接: AcWing 4496. 吃水果
这时由于r过大,线性就炸了,引入卢卡斯定理。
公式里依然提前实现(线性)C(m,r)再使用它,但是由于传进入的是C(m%p,r%p),在p较小的情况下,这个线性是可以忍受的。
因此实际上lucas定理只要m,r或者MOD有一方小就可以使用。
例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716
import io
import os
import sys
from collections import deque
from math import factorial
if os.getenv('LOCALTESTACWING'):
sys.stdin = open('input.txt')
else:
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2005
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
inv[i] = (p - p // i) * inv[p % i] % p
def exgcd(a, b):
if b == 0:
return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
# 求a关于b的逆元
def inv_exgcd(a, b):
x, y, d = exgcd(a, b)
return (x + b) % b if d == 1 else -1
def get_c(m, r):
if m < r:
return 0
c = 1
for i in range(1, r + 1):
c = c * (m - i + 1) % MOD * inv[i] % MOD
return c
def lucas(m, r):
if r == 0:
return 1
return get_c(m % MOD, r % MOD) * lucas(m // MOD, r // MOD) % MOD
if __name__ == '__main__':
n, m, k = map(int, input().split())
ans = m * lucas(n - 1, k) * ((m - 1) ** k)
print(ans % MOD)
7. 组合数取模m,r<1e18,MOD<1e5,MOD可能是合数,扩展Lucas,扩展卢卡斯
链接: AcWing 4496. 吃水果
先不写了,困了。
例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716
11. 模板代码:快速幂求单个数的逆元(费马小定理)复杂度O(lgn)
费马小定理: ( a^p - a ) 是 p 的倍数,所以可推出 a^{{p-1}}\equiv 1{\pmod {p}} , 这也是更为常用的书写形式。
因为 a^(p-1) = a * a^(p-2) , 故费马小定理可写成逆元的形式,( a * a^(p-2) ) ≡ 1 (mod p)
因此 a^(p-2) 就是所求的逆元,用快速幂求出即可。
MOD = 10**9+7
def quick_pow_mod(a, b, p):
ans = 1
while b:
if b & 1:
ans = (ans * a) % p
a = (a * a) % p
b >>= 1
return ans
# 费马小定理求a关于p的逆元。
def fermat(a, p):
return quick_pow_mod(a, p - 2, p)
a = 7712312
b = fermat(a,MOD)
print(a*b%MOD)
12. 模板代码:快速幂求单个数的逆元(扩展欧几里得)复杂度O(lgn)
ax=1 (mod p) 相当于 ax-y*p=1 (y为整数,最小的x应当是y等于零的时候,所以最后的时候把y赋成0)
ax-yp=1 , 为了便于理解,将p写成b,
即, ax+by=1。 理解为 a关于 1 模 b 的乘法逆元为 x 。
开始的 x,y,d 为任意值,但是不能等于零!!!
因为d是 gcd(a,b),所以最后给d赋值为gcd的值。若d为1,说明存在这样的x,否则不存在。
MOD = 10 ** 9 + 7
def exgcd(a, b):
if b == 0:
return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
# 求a关于b的逆元
def inv_exgcd(a, b):
x, y, d = exgcd(a, b)
return (x + b) % b if d == 1 else -1
a = 1
b = inv_exgcd(a, MOD)
print(a * b % MOD)
13. (最常用)模板代码:线性递推求n个数逆元inv[i] = (p - p / i) * (inv[p % i]) % p ,复杂度O(n)
令 a*x + b = p.
b * inv[b] ≡ 1 (mod p) , 将 b 替换为 p - a*x
(p - a * x) * inv[b] ≡ 1(mod p) ,即 p * inv[b] - (a * x * inv[b] ) ≡ 1(mod p)
因为 p mod p 等于零, 所以上式变为 -(a * x * inv[b]) ≡ 1(mod p)
观察 a * x + b = p 得 , 在计算机中 a = p/x , b = p%x .
故 - (p/x * inv[p % x] * x) ≡ 1(mod p)
因此 -p/x *inv[p % x] ≡ inv[x] (mod p)
MOD = 10 ** 9 + 7
MAX_N = 10 ** 5 + 5
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
inv[i] = (p - p // i) * inv[p % i] % p
a = 2675
b = inv[a]
print(a * b % MOD)
14. (不如用线性递推)模板代码:欧拉定理,看不明白,以后再补充。
摘自大佬博客: https://www.cnblogs.com/vongang/archive/2013/06/04/3117370.html
三、其他
- 如果还是卡常数,有些区间问题可以转化为树状数组,常数小,代码短,不过真的很难理解,还是线段树好写。遇到就套板吧。
四、更多例题
- 太多了。