[python刷题模板] 取模技巧/组合数取模模板/快速幂取模模板

一、 算法&数据结构

1. 描述

有的题目的返回结果经常很大,为了不超过int32的范围,通常题目会要求答案对一个大质数MOD进行取模。
这个操作会造成的结果是:
	一个小的数字对MOD取模依然是它本身;
	而大的数字在特定运算下,满足同余定理,满足结合律和交换率,这样可以避免不同计算方法带来的结果不同。
加减乘,都满足同余性质;注意除法不满足,这带来了组合数计算的问题。
通常题目给出的MOD有:1e9+7、998244353,他们的特点是足够大且都是质数。
  • 写在最前边:
  • 遇到取模题,直接把这句话写在最后再继续看题:return ans%MOD
  • 一定不要忘记取模导致WA!

  • 这里注意,我们用的是python:
  • 首先python的int永远不会越界,原理是如果超过2^32会自动转大数,因此支持大数计算。
  • 但千万不要因为这个就最后统一取模,大数的计算式非常慢的,会越界。因此中间步骤也要取模。
  • python3.8以后添加了math.combmath.perm函数,如果求组合数可以直接计算,最后再取模(leetcode支持、acwing不支持)。
  • perm如果不传入第二个参数,则perm(n)=factorial(n),默认全排列=阶乘。
  • 同样阶乘可以用math.factorial(python3.x)。
  • 因此可以用阶乘模拟perm和comb。
  • 幂运算可以n**k

2. 复杂度分析

  1. 通常我们希望取模代价O(1)的。单词取模当然是O(1)的。
  2. 快速幂取模
  3. 组合数取模:
    • 杨辉三角:O(n^2)。

3. 常见应用

  1. 普通取模。
  2. 快速幂取模。
  3. 组合数取模。
    • 杨辉三角(m,r<2000):C(m,r) = C(m-1,r-1)+C(m-1,r)。
      • C(m,r)为从m个数字中选r个数字的方案数,我们考虑取第m个数时的情况:任选一个数(假设是最左边的数),我们选择它,那么剩下m-1个数要选出r-1个数,方案数是C(m-1,r-1);如果不选它,要从剩下m-1个数选r个数,方案数C(m-1,r);因此总方案数C(m,r)=C(m-1,r-1)+C(m-1,r);特别的C(0,0)=1、C(i,0)=1;
    • 线性递推(m,r<1e5,MOD≈1e9):C(m,r)=C(m,r-1)*(m-r+1)/r。
      • 线性复杂度,但是公式存在除法,而除法不满足同余性质,因此要求逆元,把除法换为乘法。
      • 如果要求单个数字的逆元我们可以用扩展欧几里得或者快速幂(建议exgcd,更快点);
      • 如果要处理的是一堆,就是1-n所有数字的逆元,我们可以用欧拉筛打出逆元表,或者直接利用递推inv[i] = (p - p / i) * (inv[p % i]) % p

4. 常用优化

  1. 普通取模,如果认为数值的增长一次不会超过MOD,可以用减法来代替取模。
  2. 快速幂取模,打好板子的话,可以用循环代替递归。
  3. 组合数取模,杨辉三角、线性递推、卢卡斯定理。

二、 模板代码

0. (常用)组合数取模,预处理O(n+lgn),询问O(1):[C(m,r) = m!%MOD * inv(r!)%MOD * inv((m-r)!)%MOD]:

  • 首先要知道公式C(m,r) = m!//(r!*(m-r)!)

  • 我们知道除法不符合同余,因此要用逆元来代替除法,即C(m,r) = m!%MOD * inv(r!)%MOD * inv((m-r)!)%MOD

  • 于是需要先预处理每个1~m的阶乘fact。

  • 然后预处理阶乘的逆元,这里有两种方式:

    • nlgn方法,每个数据都套用费马小定理或扩展欧几里得,那么每次lgn,总共O(nlgn)。
    • 线性方法,把最后一个阶乘计算一次费马小定理。然后倒序递推计算前一个阶乘,我们知道fact[i-1] = fact[i]//i,把这个公式两边取逆元(倒数),得inv_f[i-1] = inv_f[i]*i,最后记得要取模。复杂度O(n+lgn)。
  • 总结 ,预处理nlgn,询问O(1)

  • 题外话:python的builtin.pow支持第三个参数mod,但是实测不如自己手写的快速幂快,会TLE。


class ModComb:
    def __init__(self, n, p):
        """
        初始化,为了防止模不一样,因此不写默认值,强制要求调用者明示
        :param n:最大值
        :param p: 模
        """
        self.p = p
        self.inv_f, self.fact = [1] * (n + 1), [1] * (n + 1)
        inv_f, fact = self.inv_f, self.fact
        for i in range(2, n + 1):
            fact[i] = i * fact[i - 1] % p
        inv_f[-1] = pow(fact[-1], p - 2, p)
        for i in range(n, 0, -1):
            inv_f[i - 1] = i * inv_f[i] % p

    def comb(self, m, r):
        if m < r or r < 0:
            return 0
        return self.fact[m] * self.inv_f[r] % self.p * self.inv_f[m - r] % self.p

    def perm_count_with_duplicate(self, a):
        """含重复元素的列表a,全排列的种类。
        假设长度n,含x种元素,分别计数为[c1,c2,c3..cx]
        则答案是C(n,c1)*C(n-c1,c2)*C(n-c1-c2,c3)*...*C(cx,cx)
        或:n!/c1!/c2!/c3!/../cn!
        """
        ans = self.fact[len(a)]
        for c in Counter(a).values():
            ans = ans * self.inv_f[c] % self.p
        return ans
        # 下边这种也可以
        # s = len(a)
        # ans = 1
        # for c in Counter(a).values():
        #     ans = ans * self.comb(s,c) % MOD
        #     s -= c
        # return ans

CF例题。这里我试了如果不用线性求逆元,每个元素都用费马小定理,预处理nlogn 会TLE。

import sys
from collections import *
from itertools import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *

if sys.hexversion == 50924784:
    sys.stdin = open('cfinput.txt')

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())

# MOD = 10 ** 9 + 7
MOD = 998244353
"""https://codeforces.com/problemset/problem/1420/D

输入 n, k (1≤k≤n≤3e5) 和 n 个闭区间,区间的范围在 [1,1e9]。
你需要从 n 个区间中选择 k 个区间,且这 k 个区间的交集不为空。
输出方案数模 998244353 的结果。
输入
7 3
1 7
3 8
4 5
6 7
1 3
5 10
8 9
输出 9

输入
3 1
1 1
2 2
3 3
输出 3

输入
3 2
1 1
2 2
3 3
输出 0

输入
3 3
1 3
2 3
3 3
输出 1

输入
5 2
1 3
2 4
3 5
4 6
5 7
输出 7
"""


#  1778 	 ms
def solve1(n, k, lr):
    def quick_pow_mod(a, b, p):
        ans = 1
        while b:
            if b & 1:
                ans = (ans * a) % p
            a = (a * a) % p
            b >>= 1
        return ans

    MAX_N = n + 1
    inv_f, fact = [1] * MAX_N, [1] * MAX_N
    p = MOD
    for i in range(2, MAX_N):
        fact[i] = i * fact[i - 1] % p
    inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
    for i in range(MAX_N - 1, 0, -1):
        inv_f[i - 1] = inv_f[i] * i % p

    def get_c(m, r):
        if m < r or r < 0:
            return 0
        # 公式C(m,r) = m!//(r!*(m-r)!)
        return fact[m] * inv_f[r] % p * inv_f[m - r] % p

    lr.sort()
    h = []
    ans = 0
    for l, r in lr:
        while h and h[0] < l:
            heapq.heappop(h)
        if len(h) >= k - 1:
            ans = (ans + get_c(len(h), k - 1)) % MOD
        heapq.heappush(h, r)
    print(ans)


#  358  	 ms
def solve2(n, k, lr):
    def quick_pow_mod(a, b, p):
        ans = 1
        while b:
            if b & 1:
                ans = (ans * a) % p
            a = (a * a) % p
            b >>= 1
        return ans

    MAX_N = n + 1
    inv_f, fact = [1] * MAX_N, [1] * MAX_N
    p = MOD
    for i in range(2, MAX_N):
        fact[i] = i * fact[i - 1] % p
    inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
    for i in range(MAX_N - 1, 0, -1):
        inv_f[i - 1] = inv_f[i] * i % p

    def get_c(m, r):
        if m < r or r < 0:
            return 0
        # 公式C(m,r) = m!//(r!*(m-r)!)
        return fact[m] * inv_f[r] % p * inv_f[m - r] % p

    l, r = [], []
    for a, b in lr:
        l.append(a)
        r.append(b)
    l.sort()
    r.sort()
    i = j = s = 0
    ans = 0
    while i < n and j < n:
        if l[i] <= r[j]:
            ans = (ans + get_c(s, k - 1)) % MOD
            i += 1
            s += 1
        else:
            j += 1
            s -= 1

    print(ans)

# 514 ms
def solve(n, k, lr):
    def quick_pow_mod(a, b, p):
        ans = 1
        while b:
            if b & 1:
                ans = (ans * a) % p
            a = (a * a) % p
            b >>= 1
        return ans

    MAX_N = n + 1
    inv_f, fact = [1] * MAX_N, [1] * MAX_N
    p = MOD
    for i in range(2, MAX_N):
        fact[i] = i * fact[i - 1] % p
    inv_f[MAX_N - 1] = pow(fact[MAX_N - 1], p - 2, p)
    for i in range(MAX_N - 1, 0, -1):
        inv_f[i - 1] = inv_f[i] * i % p

    def get_c(m, r):
        if m < r or r < 0:
            return 0
        # 公式C(m,r) = m!//(r!*(m-r)!)
        return fact[m] * inv_f[r] % p * inv_f[m - r] % p

    x = []
    for a, b in lr:
        x.append(a << 1)
        x.append(b << 1 | 1)
    x.sort()
    s = ans = 0
    for i in x:
        if i & 1:
            s -= 1
        else:
            ans = (ans + get_c(s, k - 1)) % MOD
            s += 1

    print(ans)


if __name__ == '__main__':
    n, k = RI()
    lr = []
    for _ in range(n):
        lr.append(RILST())

    solve(n, k, lr)

1. 普通取模(略),但补充python3.7以前手写排列组合数公式:

import io
import os
import sys
from collections import deque
from math import factorial

if os.getenv('LOCALTESTACWING'):
    sys.stdin = open('input.txt')
else:
    input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
if __name__ == '__main__':
    n, m, k = map(int, input().split())
    #def factorial(k):
    #   ans = 1
    #   for i in range(2,k+1):
    #       ans*=i
    #   return ans
    def comb(m,r):
        return factorial(m)//(factorial(r)*factorial(m-r))
    def perm(m,r):
        return factorial(m)//factorial(m-r)
    ans = m*comb(n-1,k)*((m-1)**k)

    print(ans % MOD)

2. 快速幂取模

链接: 1498. 满足条件的子序列数目

这题是排序后,寻找l\r长度,然后每个数字都可以出现或不出现,对答案贡献是2^n。
因此要对幂取模,如果是别的语言就要计算,python可以直接2**n%MOD。
题解在:[英雄星球七月集训LeetCode解题日报] 第9日 二分查找

MOD = 10**9+7
@cache
def fastPower(a,n):
    ans = 1
    base = a
    while n != 0:
        if (n & 1 != 0):
            ans *= base % MOD
        base *= base % MOD
        n >>= 1
    return ans % MOD
f = [1]
for i in range(10**5):
    f.append(f[-1]*2%MOD)
class Solution:
    def numSubseq(self, nums: List[int], target: int) -> int:
        nums.sort()
        n = len(nums)
        l,r=0,n-1
        ans = 0
        
        while l<=r :
            while r>=l and nums[l] + nums[r]>target:
                r -= 1
            if l>r:
                break
            # ans = (ans+2**(r-l))%MOD 
            ans = (ans+f[r-l])%MOD 
            # print(l,r,ans)
            l += 1
            
        return ans%MOD

3. 组合数取模,利用python的comb

链接: 2338. 统计理想数组的数目

一道周赛T4,难点不在这,但是用到了组合数取模。
[LeetCode周赛复盘-补] 第 300 场周赛20220710

MOD = 10**9+7
primes = [1] * 10001
primes[0] =0
primes[1] =1
for i in range(2,10000):
    if primes[i]:
        for j in range(2*i,10001,i):
            primes[j] = 0

ABC = {2,3,5,7}
@cache
def get_prime_reasons(x):
    if x == 1:
        return Counter()
    if primes[x]:
        return Counter([x])
    for i in range(2,int(x**0.5)+1):
        if x % i == 0:
            return get_prime_reasons(i) + get_prime_reasons(x//i)       
    

class Solution:
    def idealArrays(self, n: int, maxValue: int) -> int:
        ans = 0
        # print(get_prime_reasons(20))
        for i in range(1,maxValue+1):
            mul = 1
            for k in get_prime_reasons(i).values():
                mul = mul* comb(n+k-1,k)%MOD
            ans = (ans+mul)%MOD

        return ans%MOD

4. 组合数取模m,r<2000,杨辉三角O(n^2),(也叫:帕斯卡恒等式):C(m,r) = C(m-1,r-1)+C(m-1,r)

链接: AcWing 4496. 吃水果

第一次打acwing周赛的T3.用到了组合数。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716

import io
import os
import sys
from collections import deque
from math import factorial

if os.getenv('LOCALTESTACWING'):
    sys.stdin = open('input.txt')
else:
    input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2001
C = [[0] * MAX_N for _ in range(MAX_N)]
C[0][0] = 1
for i in range(1, MAX_N):
    C[i][0] = 1
    for j in range(1, MAX_N):
        C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD
if __name__ == '__main__':
    n, m, k = map(int, input().split())
    ans = m * C[n - 1][k] * ((m - 1) ** k)
    print(ans % MOD)

5. 组合数取模m,r<1e5,MOD≈1e9,线性递推O(n),C(m,r)=C(m,r-1)*(m-r+1)/r

链接: AcWing 4496. 吃水果

例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716

import io
import os
import sys
from collections import deque
from math import factorial

if os.getenv('LOCALTESTACWING'):
    sys.stdin = open('input.txt')
else:
    input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2005
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
    inv[i] = (p - p // i) * inv[p % i] % p


def get_c(m, r):
	if m<r:
		return 0
    c = 1
    for i in range(1, r + 1):
        c = c * (m - i + 1) % MOD * inv[i] % MOD
    return c


if __name__ == '__main__':
    n, m, k = map(int, input().split())
    ans = m * get_c(n - 1, k) * ((m - 1) ** k)
    print(ans % MOD)

6. 组合数取模m,r<1e18,MOD<1e5,Lucas,卢卡斯定理

链接: AcWing 4496. 吃水果
这时由于r过大,线性就炸了,引入卢卡斯定理。
公式里依然提前实现(线性)C(m,r)再使用它,但是由于传进入的是C(m%p,r%p),在p较小的情况下,这个线性是可以忍受的。
因此实际上lucas定理只要m,r或者MOD有一方小就可以使用。
例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716

import io
import os
import sys
from collections import deque
from math import factorial

if os.getenv('LOCALTESTACWING'):
    sys.stdin = open('input.txt')
else:
    input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
MOD = 998244353
MAX_N = 2005
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
    inv[i] = (p - p // i) * inv[p % i] % p


def exgcd(a, b):
    if b == 0:
        return 1, 0, a
    x, y, d = exgcd(b, a % b)
    return y, x - a // b * y, d


# 求a关于b的逆元
def inv_exgcd(a, b):
    x, y, d = exgcd(a, b)
    return (x + b) % b if d == 1 else -1


def get_c(m, r):
    if m < r:
        return 0
    c = 1
    for i in range(1, r + 1):
        c = c * (m - i + 1) % MOD * inv[i] % MOD
    return c


def lucas(m, r):
    if r == 0:
        return 1
    return get_c(m % MOD, r % MOD) * lucas(m // MOD, r // MOD) % MOD


if __name__ == '__main__':
    n, m, k = map(int, input().split())

    ans = m * lucas(n - 1, k) * ((m - 1) ** k)

    print(ans % MOD)

7. 组合数取模m,r<1e18,MOD<1e5,MOD可能是合数,扩展Lucas,扩展卢卡斯

链接: AcWing 4496. 吃水果
先不写了,困了。
例题同上。
参考链接: [acwing周赛复盘] 第 60 场周赛20220716



11. 模板代码:快速幂求单个数的逆元(费马小定理)复杂度O(lgn)

费马小定理: ( a^p - a ) 是 p 的倍数,所以可推出 a^{{p-1}}\equiv 1{\pmod {p}} , 这也是更为常用的书写形式。

因为 a^(p-1) = a * a^(p-2) , 故费马小定理可写成逆元的形式,( a * a^(p-2) ) ≡ 1 (mod p)

因此 a^(p-2) 就是所求的逆元,用快速幂求出即可。


MOD = 10**9+7
def quick_pow_mod(a, b, p):
    ans = 1
    while b:
        if b & 1:
            ans = (ans * a) % p
        a = (a * a) % p
        b >>= 1
    return ans


# 费马小定理求a关于p的逆元。
def fermat(a, p):
    return quick_pow_mod(a, p - 2, p)


a = 7712312
b = fermat(a,MOD)
print(a*b%MOD)

12. 模板代码:快速幂求单个数的逆元(扩展欧几里得)复杂度O(lgn)

ax=1 (mod p) 相当于 ax-y*p=1 (y为整数,最小的x应当是y等于零的时候,所以最后的时候把y赋成0)

ax-yp=1 , 为了便于理解,将p写成b,

即, ax+by=1。 理解为 a关于 1 模 b 的乘法逆元为 x 。

开始的 x,y,d 为任意值,但是不能等于零!!!

因为d是 gcd(a,b),所以最后给d赋值为gcd的值。若d为1,说明存在这样的x,否则不存在。

MOD = 10 ** 9 + 7


def exgcd(a, b):
    if b == 0:
        return 1, 0, a
    x, y, d = exgcd(b, a % b)
    return y, x - a // b * y, d

# 求a关于b的逆元
def inv_exgcd(a, b): 
    x, y, d = exgcd(a, b)
    return (x + b) % b if d == 1 else -1


a = 1
b = inv_exgcd(a, MOD)
print(a * b % MOD)

13. (最常用)模板代码:线性递推求n个数逆元inv[i] = (p - p / i) * (inv[p % i]) % p ,复杂度O(n)

令 a*x + b = p.

b * inv[b] ≡ 1 (mod p) , 将 b 替换为 p - a*x

(p - a * x) * inv[b] ≡ 1(mod p) ,即 p * inv[b] - (a * x * inv[b] ) ≡ 1(mod p)

因为 p mod p 等于零, 所以上式变为 -(a * x * inv[b]) ≡ 1(mod p)

观察 a * x + b = p 得 , 在计算机中 a = p/x , b = p%x .

故 - (p/x * inv[p % x] * x) ≡ 1(mod p)

因此 -p/x *inv[p % x] ≡ inv[x] (mod p)

MOD = 10 ** 9 + 7
MAX_N = 10 ** 5 + 5
inv = [0] * MAX_N
inv[1] = 1
p = MOD
for i in range(2, MAX_N):
    inv[i] = (p - p // i) * inv[p % i] % p

a = 2675
b = inv[a]
print(a * b % MOD)

14. (不如用线性递推)模板代码:欧拉定理,看不明白,以后再补充。

摘自大佬博客: https://www.cnblogs.com/vongang/archive/2013/06/04/3117370.html

在这里插入图片描述


三、其他

  1. 如果还是卡常数,有些区间问题可以转化为树状数组,常数小,代码短,不过真的很难理解,还是线段树好写。遇到就套板吧。

四、更多例题

  • 太多了。
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值