本文讲解NTL库中的基本运算代码。回忆下列常用类型
// Unsigned type of one limb (one "digit" of a multi-precision integer).
typedef unsigned long _ntl_limb_t;
// Signed counterpart of _ntl_limb_t; used below when a double estimate of a
// product's high part is converted back to an integer.
typedef long _ntl_signed_limb_t;
// Pointer to an array of limbs.
typedef _ntl_limb_t *_ntl_limb_t_ptr;
_ntl_limb_t_ptr所指内存的格式详见我的源码剖析系列文章http://t.csdn.cn/EDDzK. 在我的(32位)平台上,尽管每个数据单元_ntl_limb_t占32位,但实际只用低30位(即NTL_ZZ_NBITS = 30),高出的位用来暂存运算过程中的进位。
_ntl_limb_t在我的计算机上是32位的,NTL_LIMB_MASK是低30位全为1的掩码,所以下面的CLIP函数截取a的低30位;XCLIP则相反,清除a的低30位,只保留其上的高位(即进位部分)。
// Keep only the payload bits of a limb (the bits selected by NTL_LIMB_MASK),
// discarding any pending carry stored above them.
static inline _ntl_limb_t
CLIP(_ntl_limb_t a)
{
   const _ntl_limb_t payload = NTL_LIMB_MASK & a;
   return payload;
}
// Complement of CLIP: keep only the bits outside NTL_LIMB_MASK, i.e. the
// part of a limb that has overflowed the payload width.
static inline _ntl_limb_t
XCLIP(_ntl_limb_t a)
{
   const _ntl_limb_t overflow = ~NTL_LIMB_MASK & a;
   return overflow;
}
注意static函数不能被其它源文件调用。
// Convert the numeric value of a limb to double (a value conversion, not a
// reinterpretation of the bits).
static inline double
DBL(_ntl_limb_t x)
{
   const double value = static_cast<double>(x);
   return value;
}
加减法
使用下列加法函数前要确保储存结果的地址有足够空间。
下面是高精度加法函数,计算rp = (ap表示的n个_ntl_limb_t的数) + b,其中b是单个_ntl_limb_t(通常就是进位),相当于从最低位开始逐位加上进位。函数返回需要加到更高位的进位b。注意循环用do...while写成,所以要求n至少为1。当rp与ap是同一地址时,一旦进位变为0,剩下的高位已经就是正确结果,函数会提前返回0。
// rp[0..n-1] = ap[0..n-1] + b, where b is a single limb (typically a carry).
// Returns the carry out of the most significant limb. The loop body always
// runs at least once, so n must be >= 1. When the operation is in place
// (rp == ap), it stops as soon as the carry has been absorbed, because the
// remaining limbs of rp already hold the right values.
_ntl_limb_t
_ntl_mpn_add_1 (_ntl_limb_t *rp, const _ntl_limb_t *ap, long n, _ntl_limb_t b)
{
   const int in_place = (rp == ap);
   long k = 0;

   do {
      if (in_place && !b) return 0;

      const _ntl_limb_t sum = ap[k] + b;
      b = sum >> NTL_ZZ_NBITS;
      rp[k] = CLIP(sum);
   } while (++k < n);

   return b;
}
下面函数把up的n个_ntl_limb_t表示的整数向高位移动cnt比特位,结果(低n个单元)保存到rp,要求n至少为1。cnt应该限制在0到30(即NTL_ZZ_NBITS)之间。函数返回从最高单元移出的溢出部分。
// Shift the n-limb number at up left by cnt bits, storing the low n limbs of
// the result at rp, and return the bits shifted out of the top limb.
// Limbs are processed from the most significant end downward (each up[k-1]
// is read before rp[k] is written), so an in-place shift (rp == up) is safe.
// Requires n >= 1.
_ntl_limb_t
_ntl_mpn_lshift (_ntl_limb_t *rp, const _ntl_limb_t *up, long n, long cnt)
{
   const long back = NTL_ZZ_NBITS - cnt;   // brings a limb's top bits down
   long k = n - 1;

   _ntl_limb_t cur = up[k];
   const _ntl_limb_t spill = cur >> back;
   _ntl_limb_t pending = CLIP(cur << cnt);

   for (; k > 0; --k) {
      cur = up[k - 1];
      rp[k] = pending | (cur >> back);
      pending = CLIP(cur << cnt);
   }

   rp[0] = pending;
   return spill;
}
另一个函数_ntl_mpn_rshift是类似的,向低位进行移位运算。
下面函数计算(ap表示的数) + (bp表示的数),n指定_ntl_limb_t的个数,结果储存到rp,返回进位值。
// rp[0..n-1] = ap[0..n-1] + bp[0..n-1]; returns the carry out of the top limb.
// Each limb sum is split into its payload (CLIP) and the carry above it.
_ntl_limb_t
_ntl_mpn_add_n (_ntl_limb_t *rp, const _ntl_limb_t *ap, const _ntl_limb_t *bp, long n)
{
   _ntl_limb_t carry = 0;

   while (n-- > 0) {
      const _ntl_limb_t s = *ap++ + *bp++ + carry;
      carry = s >> NTL_ZZ_NBITS;
      *rp++ = CLIP(s);
   }

   return carry;
}
下面函数计算(ap表示的数) + (bp表示的数),结果储存到rp,an指定ap的_ntl_limb_t的个数,bn指定bp的_ntl_limb_t的个数,要求an >= bn,返回进位值。
// rp = ap + bp, where ap has an limbs, bp has bn limbs and an >= bn.
// The overlapping low bn limbs are added first; the carry is then
// propagated through the remaining an - bn limbs of ap.
// Returns the final carry.
_ntl_limb_t
_ntl_mpn_add (_ntl_limb_t *rp, const _ntl_limb_t *ap, long an, const _ntl_limb_t *bp, long bn)
{
   const _ntl_limb_t low_carry = _ntl_mpn_add_n(rp, ap, bp, bn);

   if (an <= bn)
      return low_carry;

   return _ntl_mpn_add_1(rp + bn, ap + bn, an - bn, low_carry);
}
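下面NTL_RADIX是$2^{\text{NTL\_NBITS}}$(在我的平台上为$2^{30}$),NTL_ZZ_FRADIX是同一个数$2^{\text{NTL\_NBITS}}$的double值,NTL_FRADIX_INV和NTL_ZZ_FRADIX_INV都是其倒数$2^{-\text{NTL\_NBITS}}$。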
// Radix of one limb as an integer: 2^NTL_NBITS.
#define NTL_RADIX (1L<<NTL_NBITS)
// 1/NTL_RADIX as a double.
#define NTL_FRADIX_INV (((double) 1.0)/((double) NTL_RADIX))
// The same radix 2^NTL_NBITS as a double; its inverse is used below to
// estimate the high part of a limb product in floating point.
#define NTL_ZZ_FRADIX ((double) (1L << NTL_NBITS))
// 1/NTL_ZZ_FRADIX.
#define NTL_ZZ_FRADIX_INV (1.0/NTL_ZZ_FRADIX)
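下面函数计算a + b*d + t,结果的低30位存回a,高位(进位)存入t,即(t, a) = a + b*d + t。乘积的高位部分先用浮点数估计。注意double转为long时是向零截断(对正数即向下取整)。估计值减1是为了在浮点误差不超过1的前提下不超过真实的高位,使后面的差值非负,然后再用精确的整数运算修正。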
// (t, a) = a + b*d + t: the low NTL_ZZ_NBITS bits of the sum are stored back
// into a, and the high part (the carry) into t.
static inline void
_ntl_addmulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t)
{
// Product modulo the word size; only its low bits are exact.
_ntl_limb_t t1 = b * d;
// Floating-point estimate of floor(b*d / 2^NTL_ZZ_NBITS), lowered by 1 so that
// (assuming the rounding error is below 1) it does not exceed the true high part.
_ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(d)*NTL_ZZ_FRADIX_INV) ) - 1;
// Exact correction: the remainder b*d - t2*2^NTL_ZZ_NBITS is formed with
// wrapping integer arithmetic, and its bits above the payload are added to t2.
t2 = t2 + ( (t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS );
// Add the low part of the product to a and the incoming carry t.
t1 = CLIP(t1) + a + t;
t = t2 + (t1 >> NTL_ZZ_NBITS);
a = CLIP(t1);
}
下面函数计算b*d + t,低30位存入a,高位存入t,即(t, a) = b*d + t。
// (t, a) = b*d + t
// The low NTL_ZZ_NBITS bits of b*d + t go into a, the high part into t.
// Same float-estimate-then-correct scheme as _ntl_addmulp, but without an
// incoming value of a.
static inline void
_ntl_mulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t)
{
// b*d + t modulo the word size; only its low bits are exact.
_ntl_limb_t t1 = b*d + t;
// Float estimate of the high part of b*d, lowered by 1 so it is not too large.
_ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(d)*NTL_ZZ_FRADIX_INV) ) - 1;
// Exact high part of b*d + t: correct the estimate using the wrapped remainder.
t = t2 + ((t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS);
a = CLIP(t1);
}
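类似这样的辅助运算函数还有很多。特别注意下面函数中NTL_ZZ_RADIX = $2^{\text{NTL\_ZZ\_NBITS}}$(在我的平台上为$2^{30}$)。因为$b\cdot(\text{NTL\_ZZ\_RADIX}-d) = b\cdot\text{NTL\_ZZ\_RADIX} - b\cdot d$,而$b\cdot\text{NTL\_ZZ\_RADIX}$对低30位没有贡献,只是让高位多出了$b$,所以计算a时可以认为b*NTL_ZZ_RADIX = 0,计算t时则要用hi - b把多出的$b$减掉。NTL_ZZ_RADIX - d类似于取补码,实际这里考虑的是d的低30位的补码,从而把"减去b*d"转化为加法。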
// (t, a) = b*(-d) + a + t, where t is "signed"
// a receives the low NTL_ZZ_NBITS bits of a - b*d + t; t receives the carry,
// which may be negative and is stored as an unsigned value. The subtraction
// is done by adding b*(NTL_ZZ_RADIX - d) = b*NTL_ZZ_RADIX - b*d; the extra
// b*NTL_ZZ_RADIX only adds b to the high part, which is removed below.
// NOTE(review): assumes d < NTL_ZZ_RADIX (a clipped limb) — confirm at callers.
static inline void
_ntl_submulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t)
{
// b*(RADIX - d) + a modulo the word size.
_ntl_limb_t t1 = b*(NTL_ZZ_RADIX-d) + a;
// Float estimate (minus 1) of the high part, as in _ntl_addmulp.
_ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(NTL_ZZ_RADIX-d)*NTL_ZZ_FRADIX_INV) ) - 1;
_ntl_limb_t lo = CLIP(t1);
// Exact high part of b*(RADIX - d) + a.
_ntl_limb_t hi = t2 + ((t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS);
// Add the signed incoming carry; lo may wrap around to a "negative" value.
lo += t;
a = CLIP(lo);
// Remove the b contributed by b*RADIX, and borrow one more when lo is
// negative (its top bit is set). The last term is NOT always zero.
t = hi - b - (lo >> (NTL_BITS_PER_LIMB_T-1));
}
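注意最后t的赋值语句中(lo >> (NTL_BITS_PER_LIMB_T-1))并不总是0,不能删去。t是"有符号"的(见代码注释where t is "signed"):它可能为负,只是以无符号数的形式存储,所以lo = CLIP(t1) + t可能回绕成一个"负数",此时lo的最高位为1。这一项取出的正是lo的符号位:lo为负时,a = CLIP(lo)仍是正确的低位,但还需要从高位再借1。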
下面函数_ntl_mpn_mul_1就是计算(up指向的大整数)*vl,保存到rp,返回进位值。
// rp[0..n-1] = up[0..n-1] * vl; returns the carry limb out of the top.
_ntl_limb_t
_ntl_mpn_mul_1 (_ntl_limb_t* rp, const _ntl_limb_t* up, long n, _ntl_limb_t vl)
{
   _ntl_limb_t hi = 0;

   for (; n > 0; --n, ++rp, ++up)
      _ntl_mulp(*rp, *up, vl, hi);

   return hi;
}
// rp[0..n-1] += up[0..n-1] * vl; returns the carry limb out of the top.
_ntl_limb_t
_ntl_mpn_addmul_1 (_ntl_limb_t* rp, const _ntl_limb_t* up, long n, _ntl_limb_t vl)
{
   _ntl_limb_t carry_out = 0;
   long k = 0;

   while (k < n) {
      _ntl_addmulp(rp[k], up[k], vl, carry_out);
      ++k;
   }

   return carry_out;
}
_ntl_mpn_addmul_1计算(up表示的大整数)*vl,加到rp上,返回最高位的进位值。
下面函数功能是移位减法,即计算
(carry, rp[n-1], ..., rp[0]) = (rp[n-1], ..., rp[0], shift_in) - (up[n-1], ..., up[0]) * vl
条件编译语句提供2种实现方式(#if 0下的第一种不会被编译)。第一种容易理解;第二种把for循环按4次展开。代码注释只说循环展开"似乎有一点帮助"。一个可能的原因是:_ntl_swap交换2个变量要用3个赋值语句(构造临时变量也视为赋值语句),而展开后把相邻迭代的交换合并了起来,每次_ntl_submulp之后只需2个赋值语句。
// compute (carry, rp[n-1], ..., rp[0]) = (rp[n-1], ..., rp[0], shift_in)
// - (up[n-1], ..., up[0]) * vl,
// and return carry (which may be negative, but stored as an unsigned). No
// aliasing is allowed. This is a special-purpose used by the tdiv_qr routine,
// to avoid allocating extra buffer space and extra shifting. It is not a part
// of GMP interface.
//
// Step i takes the limb currently held in shift_in (initially the shift-in
// limb, afterwards the old value of rp[i-1]), subtracts up[i]*vl from it
// (with the signed carry), stores the result in rp[i], and moves the old rp[i]
// into shift_in for the next step. The old top limb rp[n-1] ends up in
// shift_in and is combined with the final carry.
_ntl_limb_t
_ntl_mpn_shift_submul_1(_ntl_limb_t* NTL_RESTRICT rp, _ntl_limb_t shift_in, const _ntl_limb_t* NTL_RESTRICT up, long n, _ntl_limb_t vl)
{
#if 0
// Reference version: one submulp + swap per limb.
_ntl_limb_t carry = 0;
for (long i = 0; i < n; i++) {
_ntl_submulp(shift_in, up[i], vl, carry);
_ntl_swap(shift_in, rp[i]);
}
return carry + shift_in;
#else
// NOTE: loop unrolling seems to help a little bit
_ntl_limb_t carry = 0;
long i = 0;
// Unrolled by 4: each swap is replaced by a read of the next old rp limb into
// a temporary, followed by the store of the freshly computed limb.
for (; i <= n-4; i += 4) {
_ntl_submulp(shift_in, up[i], vl, carry);
_ntl_limb_t tmp1 = rp[i];
rp[i] = shift_in;
_ntl_submulp(tmp1, up[i+1], vl, carry);
_ntl_limb_t tmp2 = rp[i+1];
rp[i+1] = tmp1;
_ntl_submulp(tmp2, up[i+2], vl, carry);
_ntl_limb_t tmp3 = rp[i+2];
rp[i+2] = tmp2;
_ntl_submulp(tmp3, up[i+3], vl, carry);
shift_in = rp[i+3];
rp[i+3] = tmp3;
}
// Remaining 0..3 limbs, handled as in the reference version.
for (; i < n; i++) {
_ntl_submulp(shift_in, up[i], vl, carry);
_ntl_swap(shift_in, rp[i]);
}
// Top limb of the shifted minuend plus the (signed) carry.
return carry + shift_in;
#endif
}
下面用到的函数_ntl_addmulpsq(此处未列出)计算t + a[j]*a[j],低位存回t,进位值保存到carry。_ntl_addmulsq用b[0]去乘b的高位b[1..n],结果连同进位加到a中。_ntl_mpn_base_sqr计算$c = a^2$,sa控制a的单元数,c需有2*sa个单元。注意计算平方时,遇到$i \ne j$的交叉项$a[i]\cdot a[j] + a[j]\cdot a[i]$,可以只算一项,然后乘2。
void
_ntl_addmulsq(long n, _ntl_limb_t *a, const _ntl_limb_t *b)
{
_ntl_limb_t s = b[0];
_ntl_limb_t carry = 0;
for (long i = 0; i < n; i++) {
_ntl_addmulp(a[i], b[i+1], s, carry);
}
a[n] += carry;
}
// c[0..2*sa-1] = a[0..sa-1]^2 (schoolbook squaring).
// Each cross product a[j]*a[k] (j < k) is accumulated once via _ntl_addmulsq
// and then doubled on the fly; the diagonal terms a[j]^2 are added via
// _ntl_addmulpsq (defined elsewhere; NOTE(review): presumably it computes
// (carry, t) = t + a[j]^2 — confirm against its definition).
static inline void
_ntl_mpn_base_sqr(_ntl_limb_t *c, const _ntl_limb_t *a, long sa)
{
long sc = 2*sa;
// Clear the result; cross products are accumulated into it.
for (long i = 0; i < sc; i++) c[i] = 0;
_ntl_limb_t carry = 0;
// Step j finalizes result limbs i = 2j and i+1.
for (long i = 0, j = 0; j < sa; i += 2, j++) {
_ntl_limb_t uc, t;
// c[i] already holds all cross products for this position (from earlier
// steps); double it and add the running carry.
uc = carry + (c[i] << 1);
t = CLIP(uc);
// Add the diagonal term a[j]^2.
_ntl_addmulpsq(t, a[j], carry);
c[i] = t;
// Accumulate the cross products a[j]*a[j+1..sa-1] into c[i+1..].
_ntl_addmulsq(sa-j-1, c+i+1, a+j);
// c[i+1] is now complete: double it and add both pending carries
// (the high bits from doubling c[i] and the carry from the diagonal term).
uc = (uc >> NTL_ZZ_NBITS) + (c[i+1] << 1);
uc += carry;
carry = uc >> NTL_ZZ_NBITS;
c[i+1] = CLIP(uc);
}
}
由此观之,_ntl_addmulsq完全是为高精度平方(_ntl_mpn_base_sqr)而设计的辅助函数。
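下面函数计算(up表示的大整数)*(vp表示的大整数),结果保存到rp(rp需有un+vn个单元,且vn至少为1),返回结果的最高单元。做法就是普通的竖式乘法:先用_ntl_mpn_mul_1乘以vp[0],再依次用_ntl_mpn_addmul_1把up*vp[j]错开j个单元加上去。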
// Schoolbook multiplication: rp[0..un+vn-1] = up[0..un-1] * vp[0..vn-1].
// Row j adds up*vp[j] into rp shifted by j limbs, and the row's carry becomes
// the limb just above it. Requires vn >= 1. Returns the most significant
// limb of the product.
static inline _ntl_limb_t
_ntl_mpn_base_mul (_ntl_limb_t* rp, const _ntl_limb_t* up, long un, const _ntl_limb_t* vp, long vn)
{
   _ntl_limb_t *row = rp;
   row[un] = _ntl_mpn_mul_1(row, up, un, vp[0]);

   for (long j = 1; j < vn; j++) {
      row = rp + j;
      row[un] = _ntl_mpn_addmul_1(row, up, un, vp[j]);
   }

   return row[un];
}
其它运算
限于篇幅,没有把所有类似的加减乘除运算列出来。
下面函数计算(b表示的整数的低hsa个_ntl_limb_t)+(b表示的整数右移hsa个_ntl_limb_t),并且保存到T。sb控制b的数据单元数,要求hsa <= sb <= 2*hsa。函数返回T的有效单元数(hsa;若最后有进位则为hsa+1)。第一次读这段代码可能对命名有疑惑,事实上hsa是half sa,在以后的函数kar_mul中用到。
// T = b_lo + b_hi, where b_lo is the low hsa limbs of b and b_hi is limbs
// hsa..sb-1. Requires hsa <= sb <= 2*hsa. Returns the number of limbs written
// to T: hsa, or hsa+1 when the addition carries out.
static
long kar_fold(_ntl_limb_t *T, const _ntl_limb_t *b, long sb, long hsa)
{
   const long hi_len = sb - hsa;   // number of limbs in b_hi
   _ntl_limb_t carry = 0;
   long k;

   // Positions where both halves contribute a limb.
   for (k = 0; k < hi_len; k++) {
      const _ntl_limb_t s = b[k] + b[k + hsa] + carry;
      T[k] = CLIP(s);
      carry = s >> NTL_ZZ_NBITS;
   }

   // Remaining limbs of b_lo only absorb the carry.
   for (k = hi_len; k < hsa; k++) {
      const _ntl_limb_t s = b[k] + carry;
      T[k] = CLIP(s);
      carry = s >> NTL_ZZ_NBITS;
   }

   if (!carry)
      return hsa;

   T[hsa] = carry;
   return hsa + 1;
}
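下面是简单的减法,计算T-c并保存到T指向的地址,sc是c的单元数。使用前必须确保T表示的数不小于c表示的数;否则借位会一直向高位传播,越过T的末尾造成越界访问(甚至无限循环)。所以这个函数在实现上并不健壮,完全依赖调用者保证前置条件。在kar_mul中,这一点由Karatsuba算法本身保证:T3 = (a_lo+a_hi)*(b_lo+b_hi)不小于依次减去的a_hi*b_hi与a_lo*b_lo之和。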
// T -= c, where c has sc limbs; the borrow is propagated into the higher limbs
// of T. Precondition: T >= c as integers, otherwise the borrow loop keeps
// walking past the end of T. A negative limb difference wraps around and sets
// bit NTL_ZZ_NBITS, which is taken as the borrow.
static
void kar_sub(_ntl_limb_t *T, const _ntl_limb_t *c, long sc)
{
   _ntl_limb_t borrow = 0;
   long k;

   for (k = 0; k < sc; k++) {
      const _ntl_limb_t diff = T[k] - c[k] - borrow;
      T[k] = CLIP(diff);
      borrow = (diff >> NTL_ZZ_NBITS) & 1;
   }

   k = sc;
   while (borrow) {
      const _ntl_limb_t diff = T[k] - 1;
      T[k] = CLIP(diff);
      borrow = (diff >> NTL_ZZ_NBITS) & 1;
      k++;
   }
}
下面这个函数先令c += hsa,然后计算(c表示的整数)+(T表示的整数),保存到c(计算前会先去掉T高位的0单元)。这个函数的缺陷是第二个for循环只要有进位就一直向高位传播,并不检查c的边界,可能会访问越界。在kar_mul中这不会发生,因为最终乘积a*b一定能放进sa+sb个单元,进位不会传出c的范围。
// c += T * RADIX^hsa: add the sT-limb number T into c starting at limb offset
// hsa, then propagate the final carry upward. Leading (most significant) zero
// limbs of T are skipped. The carry loop does not check the bounds of c; the
// caller must guarantee the sum fits (in kar_mul it does, because the final
// product a*b fits in sa+sb limbs).
static
void kar_add(_ntl_limb_t *c, const _ntl_limb_t *T, long sT, long hsa)
{
   c += hsa;
   _ntl_limb_t carry = 0;

   // Ignore leading zero limbs of T.
   while (sT > 0 && T[sT-1] == 0) sT--;

   for (long i = 0; i < sT; i++) {
      _ntl_limb_t t = c[i] + T[i] + carry;
      // Extract the carry at the same width that CLIP keeps (NTL_ZZ_NBITS),
      // consistent with the other limb routines in this file.
      carry = t >> NTL_ZZ_NBITS;
      c[i] = CLIP(t);
   }

   for (long i = sT; carry; i++) {
      _ntl_limb_t t = c[i] + 1;
      carry = t >> NTL_ZZ_NBITS;
      c[i] = CLIP(t);
   }
}
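下面函数将c表示的整数的低hsa个_ntl_limb_t置为T的低hsa个_ntl_limb_t,从第hsa个单元起计算(c的对应单元)+(T的对应单元)并向高位传播进位,即c = T + (c + hsa处的数) * RADIX^{hsa}。在kar_mul的退化情形中,c + hsa处已存放b*a_hi,T存放b*a_lo,所以结果正好是a*b。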
// Combine two partial products: c[0..hsa-1] = T[0..hsa-1], then
// c[hsa..] += T[hsa..sT-1] with the carry propagated upward. If c + hsa holds
// a number H on entry, the result is c = T + H * RADIX^hsa (this is how
// kar_mul's degenerate case merges b*a_lo and b*a_hi). The carry loop does not
// check the bounds of c; the caller must guarantee the sum fits.
static
void kar_fix(_ntl_limb_t *c, const _ntl_limb_t *T, long sT, long hsa)
{
   // The low hsa limbs of the result come straight from T.
   for (long i = 0; i < hsa; i++) {
      c[i] = T[i];
   }

   _ntl_limb_t carry = 0;
   for (long i = hsa; i < sT; i++) {
      _ntl_limb_t t = c[i] + T[i] + carry;
      // Extract the carry at the same width that CLIP keeps (NTL_ZZ_NBITS),
      // consistent with the other limb routines in this file.
      carry = t >> NTL_ZZ_NBITS;
      c[i] = CLIP(t);
   }

   for (long i = sT; carry; i++) {
      _ntl_limb_t t = c[i] + 1;
      carry = t >> NTL_ZZ_NBITS;
      c[i] = CLIP(t);
   }
}
下面函数计算(a表示的整数)*(b表示的整数),结果保存到c(c需有sa+sb个单元)。sa控制a的_ntl_limb_t单元数,sb控制b的。stk是辅助栈,有sp个_ntl_limb_t的空间。不容易理解的是(hsa < sb)的情况,里面用递归把一次乘法化为三次规模约减半的乘法,降低了时间复杂度;两次kar_sub之后,T3 = (a的高位)*(b的低位) + (a的低位)*(b的高位)。关于时间复杂度的分析见代码后的内容。
// Below this many limbs in the shorter operand, schoolbook multiplication is used.
#define KARX (16)
// c[0..sa+sb-1] = a[0..sa-1] * b[0..sb-1] using Karatsuba multiplication.
// stk is scratch space with room for sp limbs; TerminalError is raised if the
// recursion needs more. c is written (as scratch T1) while a and b are still
// being read, so presumably c must not overlap a or b — confirm at callers.
static
void kar_mul(_ntl_limb_t *c, const _ntl_limb_t *a, long sa,
const _ntl_limb_t *b, long sb, _ntl_limb_t *stk, long sp)
{
// Ensure a is the longer operand (sa >= sb).
if (sa < sb) {
_ntl_swap(a, b);
_ntl_swap(sa, sb);
}
if (sb < KARX) {
/* classic algorithm */
_ntl_mpn_base_mul(c, a, sa, b, sb);
}
else {
// hsa = ceil(sa/2): length of the low halves.
long hsa = (sa + 1) >> 1;
if (hsa < sb) {
/* normal case */
// Both a and b have non-empty high halves.
_ntl_limb_t *T1, *T2, *T3;
/* allocate space */
sp -= (hsa + 1) + ((hsa << 1) + 2);
if (sp < 0) TerminalError("internal error: kmem overflow");
// T1 borrows the low part of c as scratch; it is dead before c[0..2*hsa-1]
// is overwritten by the a_lo*b_lo product below.
T1 = c;
T2 = stk; stk += hsa + 1;
T3 = stk; stk += (hsa << 1) + 2;
/* compute T1 = a_lo + a_hi */
long sT1 = kar_fold(T1, a, sa, hsa);
/* compute T2 = b_lo + b_hi */
long sT2 = kar_fold(T2, b, sb, hsa);
/* recursively compute T3 = T1 * T2 */
kar_mul(T3, T1, sT1, T2, sT2, stk, sp);
/* recursively compute a_hi * b_hi into high part of c */
/* and subtract from T3 */
kar_mul(c + (hsa << 1), a+hsa, sa-hsa, b+hsa, sb-hsa, stk, sp);
kar_sub(T3, c + (hsa << 1), sa+sb-2*hsa);
/* recursively compute a_lo*b_lo into low part of c */
/* and subtract from T3 */
kar_mul(c, a, hsa, b, hsa, stk, sp);
kar_sub(T3, c, 2*hsa);
// T3 now holds the cross terms a_lo*b_hi + a_hi*b_lo.
/* finally, add T3 * NTL_RADIX^{hsa} to c */
kar_add(c, T3, sT1+sT2, hsa);
}
else {
/* degenerate case */
// b is no longer than half of a: split only a.
_ntl_limb_t *T;
sp -= (sb + hsa);
if (sp < 0) TerminalError("internal error: kmem overflow");
T = stk; stk += sb + hsa;
/* recursively compute b*a_hi into high part of c */
kar_mul(c + hsa, a+hsa, sa-hsa, b, sb, stk, sp);
/* recursively compute b*a_lo into T */
kar_mul(T, a, hsa, b, sb, stk, sp);
/* fix-up result */
// c = b*a_lo + b*a_hi * RADIX^hsa.
kar_fix(c, T, hsa+sb, hsa);
}
}
}
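现在分析kar_mul的时间复杂度,设其时间复杂度为T(sa, sb)。由代码起始的swap技巧可以假设sa >= sb。进一步,我们分析时间复杂度的上界,所以不妨设sa = sb,从而可以设T(sa) = T(sa,sa),即双参数数列化为单参数。kar_mul最耗时间的是hsa < sb的情况。此时代码做了3次递归乘法(T1*T2、a_hi*b_hi、a_lo*b_lo,规模分别约为hsa+1、sa-hsa、hsa),另外还有kar_fold、kar_sub、kar_add这些线性时间的操作,得到以下递推式
$$T(sa) \le 2T(\lceil sa/2\rceil) + T(\lceil sa/2\rceil + 1) + O(sa),$$
即(忽略取整和加1)
$$T(sa) = 3T(sa/2) + O(sa).$$
再由算法分析的Master Theorem($a=3$,$b=2$,$f(sa)=O(sa)$,而$\log_2 3 > 1$)知$T(sa) = \Theta(sa^{\log_2 3}) \approx \Theta(sa^{1.585})$。再考虑一般的情况(sa >= sb):退化情形会把a不断对半拆分,直到每一块的长度与sb相当,共约sa/sb块,每块的代价为$O(sb^{\log_2 3})$,就有
$$T(sa, sb) = O\left(\frac{sa}{sb}\cdot sb^{\log_2 3}\right) = O\left(sa\cdot sb^{\log_2 3 - 1}\right).$$
因此这里用技巧确实降低了乘法的时间复杂度(普通竖式乘法为$O(sa\cdot sb)$)。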
浮点数
除了大整数,NTL还提供浮点数。_ntl_limb_t是32位无符号整数,有效数据有30位。前面的DBL函数把x的整数值转换为double:这是数值转换,而不是把x的二进制位按位重新解释为浮点数。由于30位整数能被double(53位有效数字)精确表示,这个转换没有误差。