对于同余方程组:
? ≡ ?1 (??? ?1)
? ≡ ?2 (??? ?2)
…
? ≡ ?? (??? ??)
若有 ?1, ?2 … ?? 互质,可以用普通的中国剩余定理求解。
但若 ?1, ?2 … ?? 不互质,就需要用到扩展中国剩余定理。
它的原理和普通的中国剩余定理类似,也是通过合并同余方程构造解。
区别就在于,因为?1, ?2… 不互质,所以需要把方程两两合并,
例如合并1、2,就可以得到
? ≡ ?1 (??? ?1)
? ≡ ?2 (??? ?2)
↓
? ≡ ?' ( ??? (?1?2 / $gcd$(?1, ?2) ) )
以此类推,如果合并了所有方程,得到的?就是最终的解。
那么,这一过程是如何实现的?
证明
通过前两个方程可以得出
∵ $?$ ≡ $?$1 + $?$1 $y$1
$?$ ≡ $?$2 + $?$2 $y$2
∴ $?$1 + $?$1 $y$1 = $?$2 + $?$2 $y$2
$?$1 $y$1 - $?$2 $y$2 = $?$2 - $?$1
这个式子中,$?$1,$?$2,$?$2 - $?$1 是已知的,且在模?1?2 / $gcd$(?1, ?2)时成立
可以发现,它刚好符合$Ax+By=C$的形式。
和普通的中国剩余定理求逆元时的操作类似,可以用$exgcd$求出$x$,也就是$y$1。
(先用$exgcd$求出$Ax+By=gcd(A,B)$,可知要求的$x'=C/gcd(A.B)*x$,然后将结果+B%B+B变为最小正数)
通过方程$?$ ≡ $?$1 + $?$1 $y$1构造出一个模?1?2 / $gcd$(?1, ?2)下的解$x$,
于是就得到了新方程
$?$ ≡ $?$1 + $?$1 $y$1 ( ??? (?1?2 / $gcd$(?1, ?2) ) )
用新的$x$代替$?$1,再将新方程与3合并……以此类推,合并所有方程后构造出的解$x$即为答案。
这部分的代码如下:
void excrt() { for(int i = 2; i <= n; i++) { A = m[1], B = m[i], C = a[i]-a[1]; C = (C%B+B)%B; int g = exgcd(A,B,x,y); x = x*(C/g); x = (x%B+B)%B; a[1] = a[1]+ m[1]*x; m[1] = m[1]*(m[i]/g); a[1] = (a[1]%m[1]+m[1])%m[1]; } }
板子题:Luogu P4777 【模板】扩展中国剩余定理(EXCRT)
因为数据可能很大,所以需要用快速乘经常取模,并且先除后乘。
代码如下
![](https://i-blog.csdnimg.cn/blog_migrate/8f900a89c6347c561fdf2122f13be562.gif)
![](https://i-blog.csdnimg.cn/blog_migrate/961ddebeb323a10fe0623af514929fc1.gif)
#include<cstdio> #include<iostream> #include<cmath> #include<cstring> #define MogeKo qwq #define int long long using namespace std; const int maxn = 1e6; int n,x,y,A,B,C; int m[maxn],a[maxn],ans; int qmul(int a,int b,int mo){ int ans = 0,base = a; while(b){ if(b&1) ans = (ans+base) %mo; base = (base+base) %mo; b >>= 1; } return ans; } int exgcd(int a,int b,int &x,int &y) { if(!b) { x=1, y=0; return a; } int g = exgcd(b,a%b,x,y); int tx = x; x = y; y = tx-(a/b)*y; return g; } void excrt() { for(int i = 2; i <= n; i++) { A = m[1], B = m[i], C = a[i]-a[1]; C = (C%B+B)%B; int g = exgcd(A,B,x,y); x = qmul(x,(C/g),B); x = (x%B+B)%B; a[1] = a[1]+ qmul(m[1],x,m[1]*(m[i]/g)); m[1] = m[1]*(m[i]/g); a[1] = (a[1]%m[1]+m[1])%m[1]; } } main() { scanf("%lld",&n); for(int i = 1; i <= n; i++) scanf("%lld%lld",&m[i],&a[i]); excrt(); printf("%lld",a[1]); }