模意义下乘法运算的逆元(Modular Multiplicative Inverse),如何使用扩展欧几里德算法(Extended Euclidean algorithm)求解乘法逆元
逆元简介
如果一个线性同余方程 a x ≡ 1 ( m o d b ) ax \equiv 1 \pmod b ax≡1(modb),则 x x x 称为 a m o d b a \bmod b amodb 的逆元,记作 a − 1 a^{-1} a−1。
如何求逆元
扩展欧几里得法
int extend_gcd(int a, int b, int &x, int &y){
if (!b){
x = 1; y = 0;
return a;
}
int d = extend_gcd(b, a % b, y, x);
y -= (a/b) * x;
return d;
}
int mod_inverse(int a, int m){
int x, y;
extend_gcd(a, m, x, y);
return (m + x % m) % m;
}
扩展欧几里得法和求解线性同余方程是一个原理,在这里不展开解释。
快速幂法
因为 a x ≡ 1 ( m o d b ) ax \equiv 1 \pmod b ax≡1(modb);
所以 a x ≡ a b − 1 ( m o d b ) ax \equiv a^{b-1} \pmod b ax≡ab−1(modb)(根据 费马小定理);
所以 x ≡ a b − 2 ( m o d b ) x \equiv a^{b-2} \pmod b x≡ab−2(modb)。
然后我们就可以用快速幂来求了。
注意使用费马小定理需要限制 b b b 是一个素数,而扩展欧几里得算法只要求 gcd ( a , p ) = 1 \gcd(a, p) = 1 gcd(a,p)=1。
线性求逆元
求出 1 , 2 , . . . , n 1,2,...,n 1,2,...,n 中每个数关于 p p p 的逆元。
如果对于每个数进行单次求解,以上两种方法就显得慢了,很有可能超时,所以下面来讲一下如何线性( O ( n ) O(n) O(n))求逆元。
首先,很显然的 1 − 1 ≡ 1 ( m o d p ) 1^{-1} \equiv 1 \pmod p 1−1≡1(modp)
对于 ∀ p ∈ Z \forall p \in \mathbf{Z} ∀p∈Z,有 1 × 1 ≡ 1 ( m o d p ) 1 \times 1 \equiv 1 \pmod p 1×1≡1(modp) 恒成立,故在 p p p 下 1 1 1 的逆元是 1 1 1,而这是推算出其他情况的基础。
其次对于递归情况 i − 1 i^{-1} i−1,我们令 k = ⌊ p i ⌋ k = \lfloor \frac{p}{i} \rfloor k=⌊ip⌋, j = p m o d i j = p \bmod i j=pmodi,有 p = k i + j p = ki + j p=ki+j。再放到 m o d p \mod p modp 意义下就会得到: k i + j ≡ 0 ( m o d p ) ki+j \equiv 0 \pmod p ki+j≡0(modp);
两边同时乘 i − 1 × j − 1 i^{-1} \times j^{-1} i−1×j−1:
k j − 1 + i − 1 ≡ 0 ( m o d p ) kj^{-1}+i^{-1} \equiv 0 \pmod p kj−1+i−1≡0(modp)
i − 1 ≡ − k j − 1 ( m o d p ) i^{-1} \equiv -kj^{-1} \pmod p i−1≡−kj−1(modp)
再带入 j = p m o d i j = p \bmod i j=pmodi,有 p = k i + j p = ki + j p=ki+j,有:
i − 1 ≡ − ⌊ p i ⌋ ( p m o d i ) − 1 ( m o d p ) i^{-1} \equiv -\lfloor\frac{p}{i}\rfloor (p \bmod i)^{-1} \pmod p i−1≡−⌊ip⌋(pmodi)−1(modp)
我们注意到 p m o d i < i p \bmod i < i pmodi<i,而在迭代中我们完全可以假设我们已经知道了所有的模 p p p 下的逆元 j − 1 , j < i j^{-1}, j < i j−1,j<i。
故我们就可以推出逆元,利用递归的形式,而使用迭代实现:
i − 1 ≡ { 1 , if i = 1 , − ⌊ p i ⌋ ( p m o d i ) − 1 , otherwises . ( m o d p ) i^{-1} \equiv \begin{cases} 1, & \text{if } i = 1, \\ -\lfloor\frac{p}{i}\rfloor (p \bmod i)^{-1}, & \text{otherwises}. \end{cases} \pmod p i−1≡{1,−⌊ip⌋(pmodi)−1,if i=1,otherwises.(modp)
vector<int> get_inv(int n, int p) {
vector<int>inv(n + 1);
inv[1] = 1;
for (int i = 2; i <= n; ++i) inv[i] = ll(p - p / i) * inv[p % i] % p;
return inv;
}
使用 p − ⌊ p i ⌋ p-\lfloor \dfrac{p}{i} \rfloor p−⌊ip⌋ 来防止出现负数。
另外我们注意到我们没有对 inv[0]
进行定义却可能会使用它:当
i
∣
p
i | p
i∣p 成立时,我们在代码中会访问 inv[p % i]
,也就是 inv[0]
,这是因为当
i
∣
p
i | p
i∣p 时不存在
i
i
i 的逆元
i
−
1
i^{-1}
i−1。线性同余方程 中指出,如果
i
i
i 与
p
p
p 不互素时不存在相应的逆元(当一般而言我们会使用一个大素数,比如
1
0
9
+
7
10^9 + 7
109+7 来确保它有着有效的逆元)。因此需要指出的是:如果没有相应的逆元的时候,inv[i]
的值是未定义的。
另外,根据线性求逆元方法的式子: i − 1 ≡ − k j − 1 ( m o d p ) i^{-1} \equiv -kj^{-1} \pmod p i−1≡−kj−1(modp)
递归求解 j − 1 j^{-1} j−1, 直到 j = 1 j=1 j=1 返回 1 1 1。
中间优化可以加入一个记忆化来避免多次递归导致的重复,这样求 1 , 2 , . . . , n 1,2,...,n 1,2,...,n 中所有数的逆元的时间复杂度仍是 O ( n ) O(n) O(n)。
注意:如果用以上给出的式子递归进行单个数的逆元求解,目前已知的时间复杂度的上界为 O ( n 1 3 ) O(n^{\frac 1 3}) O(n31),具体请看 知乎讨论。算法竞赛中更好地求单个数的逆元的方法有扩展欧几里得法和快速幂法。
线性求任意 n 个数的逆元
上面的方法只能求 1 1 1 到 n n n 的逆元,如果需要求任意给定 n n n 个数( 1 ≤ a i < p 1 \le a_i < p 1≤ai<p)的逆元,就需要下面的方法:
首先计算 n n n 个数的前缀积,记为 s i s_i si,然后使用快速幂或扩展欧几里得法计算 s n s_n sn 的逆元,记为 s v n sv_n svn
因为 s v n sv_n svn 是 n n n 个数的积的逆元,所以当我们把它乘上 a n a_n an 时,就会和 a n a_n an 的逆元抵消,于是就得到了 a 1 a_1 a1 到 a n − 1 a_{n-1} an−1 的积逆元,记为 s v n − 1 sv_{n-1} svn−1
同理我们可以依次计算出所有的 s v i sv_i svi,于是 a i − 1 a_i^{-1} ai−1 就可以用 s i − 1 × s v i s_{i-1} \times sv_i si−1×svi 求得。
所以我们就在 O ( n + log p ) O(n + \log p) O(n+logp) 的时间内计算出了 n n n 个数的逆元。
ll qpow(ll x, ll n, ll mod) {
ll ans = 1;
while (n) {
if (n & 1)ans = (ans * x )% mod;
x = (x * x )% mod;
n >>= 1;
}
return ans % mod;
}
vector<int> get_inv(vector<int> a, int p) {
int n = a.size() - 1;
vector<int> s(n + 1), sv(n + 1), inv(n + 1);
s[0] = 1;
for (int i = 1; i <= n; ++i) s[i] = (ll)s[i - 1] * a[i] % p;
sv[n] = qpow(s[n], p - 2, p);
for (int i = n; i >= 1; --i) sv[i - 1] = (ll)sv[i] * a[i] % p;
for (int i = 1; i <= n; ++i) inv[i] = (ll)sv[i] * s[i - 1] % p;
return inv;
}
逆元练习题
#include<bits/stdc++.h>
#include <unordered_map>
using namespace std;
template<class...Args>
void debug(Args... args) {//Parameter pack
auto tmp = { (cout << args << ' ', 0)... };
cout << "\n";
}
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
const ll N = 5e3 + 5;
const ll INF = 0x7fffffff;
const ll MOD = 1e9 + 7;
vector<int> get_inv(int n, int p) {
vector<int>inv(n + 1);
inv[1] = 1;
for (int i = 2; i <= n; ++i) inv[i] = ll(p - p / i) * inv[p % i] % p;
return inv;
}
int main() {
ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int n, p;
cin >> n >> p;
vector<int> v=get_inv(n, p);
for (int i = 1; i <= n; i++)cout << v[i] << "\n";
return 0;
}
#include<bits/stdc++.h>
#include <unordered_map>
using namespace std;
template<class...Args>
void debug(Args... args) {//Parameter pack
auto tmp = { (cout << args << ' ', 0)... };
cout << "\n";
}
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
const ll N = 5e6 + 5;
const ll INF = 0x7fffffff;
const int MOD = 1e9 + 7;
#define re register
#define gc pa==pb&&(pb=(pa=buf)+fread(buf,1,100000,stdin),pa==pb)?EOF:*pa++
static char buf[100000],*pa(buf),*pb(buf);
inline int read() {
re int x(0);re char c(gc);
while(c<'0'||c>'9')c=gc;
while(c>='0'&&c<='9')
x=x*10+c-48,c=gc;
return x;
}
ll qpow(ll x, ll n, ll mod) {
ll ans = 1;
while (n) {
if (n & 1)ans = (ans * x )% mod;
x = (x * x )% mod;
n >>= 1;
}
return ans % mod;
}
vector<int> get_inv(vector<int> a, int p) {
int n = a.size() - 1;
vector<int> s(n + 1), sv(n + 1), inv(n + 1);
s[0] = 1;
for (int i = 1; i <= n; ++i) s[i] = (ll)s[i - 1] * a[i] % p;
sv[n] = qpow(s[n], p - 2, p);
for (int i = n; i >= 1; --i) sv[i - 1] = (ll)sv[i] * a[i] % p;
for (int i = 1; i <= n; ++i) inv[i] = (ll)sv[i] * s[i - 1] % p;
return inv;
}
int main() {
ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int n, p, k;
n=read();p=read();k=read();
vector<int>a(n + 1);
for (int i = 1; i <= n; i++)a[i]=read();
vector<int> v = get_inv(a, p);
ll ans = 0, t = k;
for (int i = 1; i <= n; i++) {
ans = (ans + t * (ll)v[i]) % p;
t = (t * k) % p;
}
cout << ans << "\n";
return 0;
}