扩展欧几里得算法详解
对于一个方程a∗x+b∗y=gcd(a,b) 来说,我们可以做如下的推导:
设有a∗x1+b∗y1=gcd(a,b); ①
同时我们有b∗x2+(a%b)∗y2=gcd(b,a%b); ②
- 因为 (a%b)=(a-[a/b]∗b); 所以公式②可以转化为b∗x2+(a[a/b]∗b)∗y2=gcd(b,a%b); ③
- ③化简得a∗y2+b∗(x2- [a/b]∗y2)=gcd(b,a%b); ④;
- 因为gcd(a,b)=gcd(b,a%b);
所以我们最终的到两个式子:
a∗x1+b∗y1=gcd(a,b); ①
a∗y2+b∗(x2- [a/b]∗y2)=gcd(b,a%b); ②
对于这个方程组,我们可以知道的是x1, x2,y1,y2之间的关系:
x1=y2
y1=x2−[a/b]∗y2
这个递归的边界是什么呢?我们知道,当朴素欧几里得到达边界时,return gcd(a,0)=a,那么边界条件就是对a∗x0+0∗y0=a求解,很显然,此时x0=1,y0=0(这是a∗x+b∗y=gcd
(a,b)的一个特解)
当我们递归求出了一个方程的特解时,如何求出这个方程的通解呢?
方程a∗x+b∗y=gcd(a,b) 中,如果将x加上一个常数k1,y减去一个常数k2,仍然保持原方程成立,那么x+k1,y−k2就是方程的一个新解,这个k应该如何选择呢?
实际上很简单a∗(x+k1)+b∗(y+k2)=gcd(a,b),打开括号,a∗x+a∗k1+b∗y−b∗k2=gcd(a,b);
我们保证原方程成立,就需要a∗k1 == b ∗k2,那么显然k1=b,k2=a是一种合理的情况,但是这样是无法包含所有整数解的,因为我们加上的这个值并非是最小值
那我们应该加上什么值才行呢?我们发现当a∗k1 == b∗k2=t∗lcm(a,b) 可以保证得到所有解,于是每次寻找解就可以分别在x加上b/gcd(a,b),在y减去a/gcd(a,b)
对于方程a∗x+b∗y=c我们又该如何求解?我们发现如果 (c%gcd(a,b)!=0) 那么这个方程是无解的,而如果gcd(a,b)∗t == c,我们就可以按上面的方法求解之后对我们的解乘上一个t(t=c/gcd(a,b));
#include<cstdio>
#include<cmath>
#include<iostream>
using namespace std;
int exgcd(int a,int b,int &x,int &y)//扩展欧几里得算法
{
if(b==0)
{
x=1;y=0;
return a; //到达递归边界开始向上一层返回
}
int r=exgcd(b,a%b,x,y);
int temp=y; //把x y变成上一层的
y=x-(a/b)*y;
x=temp;
return r; //得到a b的最大公因数
}
int main()
{
int x,y,a=6,b=7,c=2;
int k1,k2,k;
int gcd=exgcd(a,b,x,y); //求a*x+b*y=c的一组解;a,b,c是已知的,gcd是a和b的最大公约数
if(c%gcd!=0)
printf("Impossible\n"); //如果c不是最大公约数的倍数,那么该等式无整数解
printf("%d\n",gcd);
printf("%d %d\n",x,y);//方程a∗x+b∗y=gcd(a,b)的一组解
k1=b/gcd;
k2=a/gcd;
k=c/gcd; //求c是gcd的几倍
printf("k1 : %d\nk2 : %d\n",k1,k2);
printf("输出方程的多组解\n");
for(int i=0;i<5;i++)
printf("%d %d\n",x*k+k1*i,y*k-k2*i); //方程a∗x+b∗y=c的多组解
return 0;
}
例题:
1352 集合计数
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<map>
using namespace std;
const int Max=1e5+10;
long long exgcd(long long a,long long b,long long &x,long long &y)
{
if(b==0)
{
x=1,y=0;
return a;
}
long long gcd=exgcd(b,a%b,x,y);
int t=y;
y=x-(a/b)*y;
x=t;
return gcd;
}
int main()
{
int T;
long long N,a,b,gcd,x,y,k,k1,k2;
scanf("%d",&T);
while(T--)
{
scanf("%lld%lld%lld",&N,&a,&b);
gcd=exgcd(a,b,x,y);
if((N+1)%gcd)
printf("0\n");
else {
k=(N+1)/gcd;
k1=b/gcd;
k2=a/gcd;
x=((x*k)%k1+k1)%k1;//计算x的最小正整数解
if(x==0)
x=k1;
y=((y*k)%k2+k2)%k2;//计算y的最下正整数解
if(y==0)
y=k2;
long long maxs_x=((N+1)-b*y)/a;// 因为(x+i*k1)*a+(y+j*k2)*b=N+1 (i,j是任意数)
maxs_x=(maxs_x-x)/k1;//这两步是计算在y最小的情况下,i最大是多少,i最小取0
printf("%lld\n",maxs_x+1);
}
}
return 0;
}
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<map>
using namespace std;
const int Max=1e5+10;
typedef long long ll;
ll exgcd(ll a,ll b,ll &xx,ll &yy)
{
if(b==0)
{
xx=1,yy=0;
return a;
}
ll gcd=exgcd(b,a%b,xx,yy);
ll t=yy;
yy=xx-(a/b)*yy;
xx=t;
return gcd;
}
int main()
{
ll xx,yy,x,y,m,n,L,a,b,c;
/*根据题意可得公式: 设k为青蛙跳的次数
(x+m*k)%L==(y+n*k)%L
-->(x+m*k)≡(y+n*k)(mod L)
--> (x+m*k)+k'L=y+n*k
--> (m-n)*k+k'*L=y-x */
scanf("%lld%lld%lld%lld%lld",&x,&y,&m,&n,&L);
if(m>n)//要保证a的值是正数
c=y-x,a=m-n,b=L;
else c=x-y,a=n-m,b=L;
ll gcd=exgcd(a,b,xx,yy);
if(c%gcd)
printf("Impossible\n");
else {
ll k=c/gcd;
ll k1=b/gcd;
xx=((xx*k)%k1+k1)%k1;//取x的最小值
printf("%lld\n",xx);
}
return 0;
}
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<map>
using namespace std;
const int Max=1e5+10;
typedef long long ll;
ll exgcd(ll a,ll b,ll &x,ll &y)
{
if(b==0)
{
x=1;
y=0;
return a;
}
ll gcd=exgcd(b,a%b,x,y);
ll t=y;
y=x-(a/b)*y;
x=t;
return gcd;
}
int main()
{
ll T,n,B,a,b,c,x,y,k,k1;
/*
因为 n=A%9973;
所以 n=A-[A/9973]*9973;
设 x=A/B;y=[A/9973]; 所以 A=B*x;
n=B*x-9973*y;
*/
scanf("%d",&T);
while(T--)
{
scanf("%lld%lld",&n,&B);
a=B;b=9973;c=n;
ll gcd=exgcd(a,b,x,y);
k=c/gcd;
k1=b/gcd;
x=((x*k)%k1+k1)%k1;
printf("%lld\n",x%9973);
}
return 0;
}
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<map>
using namespace std;
const int Max=1e5+10;
typedef long long ll;
ll exgcd(ll a,ll b,ll &x,ll &y)
{
if(b==0)
{
x=1,y=0;
return a;
}
ll gcd=exgcd(b,a%b,x,y);
ll t=y;
y=x-(a/b)*y;
x=t;
return gcd;
}
ll Gcd(ll a,ll b)
{
ll r;
while(b)
{
r=a%b;
a=b;
b=r;
}
return a;
}
int main()
{
long long a,b,s,x,y,k1,k2,k,gcd;
while(scanf("%lld%lld%lld",&a,&b,&s)!=EOF)
{
if(a==0&&b)
{
if(s%b)
printf("NO\n");
else printf("YES\n");
continue;
}
else if(b==0&&a)
{
if(s%a)
printf("NO\n");
else printf("YES\n");
continue;
}
else if(a==0&&b==0)
{
if(s)
printf("NO\n");
else printf("YES\n");
continue;
}
else if(a>s||b>s)
{
printf("NO\n");
continue;
}
gcd=exgcd(a,b,x,y);
if(s%gcd)
printf("NO\n");
else {
k=s/gcd;
k1=b/gcd;
k2=a/gcd;
x=(((x%k1)*(k%k1))%k1+k1)%k1;
y=(s-a*x)/b;
int flag=0;
while(y>0)
{
if(Gcd(x,y)==1)//不太懂为什么gcd(x,y)==1
{
printf("YES\n");
flag=1;break;
}
x+=k1;
y-=k2;
}
if(flag==0)
printf("NO\n");
}
}
return 0;
}