中国国剩余定理,又名孙子定理,是中国古代求解一次同余式组的方法。当前可用于解决密码学中的秘密共享问题,也可以用来解决布隆过滤器的false positive问题。
1 中国剩余定理
如果,有如下一元线性同余方程组
(
p
i
为
素
数
)
(p_i为素数)
(pi为素数)
(1)
{
x
≡
a
1
(
m
o
d
  
p
1
)
x
≡
a
2
(
m
o
d
  
p
2
)
.
.
.
x
≡
a
n
(
m
o
d
  
p
n
)
\left\{ \begin{matrix} x\equiv a_1 (\mod p_1) \\ x\equiv a_2 (\mod p_2) \\ . \\ . \\ . \\ x\equiv a_n (\mod p_n) \end{matrix} \right . \tag{1}
⎩⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎧x≡a1(modp1)x≡a2(modp2)...x≡an(modpn)(1)
方程的解为
(2)
x
=
(
∑
i
=
1
n
a
i
t
i
M
i
)
+
k
M
x=(\sum_{i=1}^na_it_iM_i) + kM \tag{2}
x=(i=1∑naitiMi)+kM(2)
其中
M
=
∏
i
=
1
n
p
i
,
M
i
=
M
/
p
i
,
t
i
=
M
i
−
1
(
m
o
d
  
p
i
)
M=\prod_{i=1}^np_i, M_i = M/p_i, t_i = M_i^{-1}(\mod p_i)
M=∏i=1npi,Mi=M/pi,ti=Mi−1(modpi)
显然解
x
x
x满足上面的一次同余方程组,所以
x
x
x是方程组的解。接下来只需要证明同余方程组的解的周期为
M
M
M即可。
假设有两个解
x
1
,
x
2
x_1,x_2
x1,x2,那么
x
1
−
x
2
≡
0
(
m
o
d
  
p
i
)
x_1-x_2\equiv 0(\mod p_i)
x1−x2≡0(modpi),可得
x
1
−
x
2
x_1-x_2
x1−x2整除
M
M
M,所以解得周期为
M
M
M,得证.
2 扩展欧几里得算法
辗转相除法可以很方便的求解两个数的最大公约数,如果把求解过程中记录下来就可以用于求解方程
a
x
+
b
y
=
g
c
d
(
a
,
b
)
ax+by=gcd(a,b)
ax+by=gcd(a,b),这个方法可以用于求解上面涉及到的
t
i
t_i
ti,
t
i
=
M
i
−
1
(
m
o
d
  
p
i
)
t_i = M_i^{-1}(\mod p_i)
ti=Mi−1(modpi)即
(3)
t
i
×
M
i
≡
1
(
m
o
d
  
p
i
)
t_i\times M_i \equiv 1(\mod p_i) \tag{3}
ti×Mi≡1(modpi)(3)
所以只需要求解
(4)
x
M
i
+
y
p
i
=
1
=
g
c
d
(
M
i
,
p
i
)
xM_i+yp_i =1= gcd(M_i,p_i) \tag{4}
xMi+ypi=1=gcd(Mi,pi)(4)
(5)
t
i
=
x
t_i = x\tag{5}
ti=x(5)
3 参考代码
#include <stdio.h>
#include <stdlib.h>
typedef int int32;
// ax + by = gcd(a, b)
int32 exgcd(int32 a, int32 b, int32 *x, int32 *y)
{
if(b == 0) {
*x = 1;
*y = 0;
return a;
} else {
int32 r = exgcd(b, a % b, x, y);
int32 t = *x;
*x = *y;
*y = t - a / b * (*y);
return r;
}
}
int32 func(int32 *a, int32 *p, int32 len)
{
int32 i, *t;
int32 M = 1, *m, temp, result = 0;
m = (int32 *)malloc(len * sizeof(int32));
t = (int32 *)malloc(len * sizeof(int32));
for(i = 0; i < len; i++) {
m[i] = 1;
M *= p[i];
}
printf("\n\nM = %d\nMi ", M);
for(i = 0; i < len; i++) {
m[i] = M / p[i];
printf("%5d ", m[i]);
exgcd(m[i], p[i], &t[i], &temp);
}
printf("\nTi ");
for(i = 0; i < len; i++) {
result += a[i] * t[i] * m[i];
printf("%5d ", t[i]);
}
printf("\n");
while (result <= 0)
result += M;
free(m);
free(t);
return result % M;
}
int main()
{
int32 i, len = 4, result;
int32 a[10] = {1, 2, 4, 1}, p[10] = {5, 3, 7, 11};
printf("primer : ");
for(i = 0; i < len; i++) {
printf("%3d ", p[i]);
}
printf("\nremaind: ");
for(i = 0; i < len; i++) {
printf("%3d ", a[i]);
}
result = func(a, p, len);
printf("\nresult = %d\n\n", result);
system("pause");
}