NTL(Number Theory Library)源码剖析(2)__基本运算

本文讲解NTL库中的基本运算代码。回忆下列常用类型

typedef unsigned long _ntl_limb_t;
typedef long _ntl_signed_limb_t;
typedef _ntl_limb_t *_ntl_limb_t_ptr;

_ntl_limb_t_ptr所指内存的格式详见我的源码剖析系列文章http://t.csdn.cn/EDDzK. 尽管每个数据单元_ntl_limb_t占32位,但实际只用低30位。

NTL_LIMB_MASK在我的计算机上是32位的(001111...1)_2,下面的CLIP函数即截取a的低30位。

static 
inline _ntl_limb_t CLIP(_ntl_limb_t a)
{
   return a & NTL_LIMB_MASK;
}

static 
inline _ntl_limb_t XCLIP(_ntl_limb_t a)
{
   return a & ~NTL_LIMB_MASK;
}

注意static函数不能被其它源文件调用。

static inline double 
DBL(_ntl_limb_t x)
{
   return double(x);
}

加减法

使用下列加法函数前要确保储存结果的地址有足够空间。

下面是高精度加法函数,可以认为初始时rp是0,b相当于进位标志,先把b加到rp,把ap的低n个_ntl_limb_t加到rp. 返回需要加到更高位的b.

_ntl_limb_t
_ntl_mpn_add_1 (_ntl_limb_t *rp, const _ntl_limb_t *ap, long  n, _ntl_limb_t b)
{
  long i;

  if (rp != ap) {
    i = 0;
    do
      {
	_ntl_limb_t r = ap[i] + b;
	rp[i] = CLIP(r);
	b = r >> NTL_ZZ_NBITS;
      }
    while (++i < n);

    return b;
  }
  else {
    i = 0;
    do
      {
        if (!b) return 0;
	_ntl_limb_t r = ap[i] + b;
	rp[i] = CLIP(r);
	b = r >> NTL_ZZ_NBITS;
      }
    while (++i < n);

    return b;
  }
}

下面函数把up的n个_ntl_limb_t向高位移动cnt比特位,结果保存到rp. 应该要限制cnt在0到30之间。返回溢出部分。

_ntl_limb_t
_ntl_mpn_lshift (_ntl_limb_t *rp, const _ntl_limb_t *up, long n, long cnt)
{
  _ntl_limb_t high_limb, low_limb;
  long tnc;
  _ntl_limb_t retval;

  up += n;
  rp += n;

  tnc = NTL_ZZ_NBITS - cnt;
  low_limb = *--up;
  retval = low_limb >> tnc;
  high_limb = CLIP(low_limb << cnt);

  while (--n != 0)
    {
      low_limb = *--up;
      *--rp = high_limb | (low_limb >> tnc);
      high_limb = CLIP(low_limb << cnt);
    }
  *--rp = high_limb;

  return retval;
}

另一个函数_ntl_mpn_rshift是类似的,向低位进行移位运算。

下面函数计算(ap表示的数) + (bp表示的数),n指定_ntl_limb_t的个数,结果储存到rp,返回进位值。


_ntl_limb_t
_ntl_mpn_add_n (_ntl_limb_t *rp, const  _ntl_limb_t *ap, const _ntl_limb_t *bp, long n)
{
  long i;
  _ntl_limb_t cy;

  for (i = 0, cy = 0; i < n; i++)
    {
      _ntl_limb_t sum = ap[i] + bp[i] + cy;
      rp[i] = CLIP(sum);
      cy = sum >> NTL_ZZ_NBITS;
    }
  return cy;
}

下面函数计算(ap表示的数) + (bp表示的数),结果储存到rp,an指定ap的_ntl_limb_t的个数,bn指定bp的_ntl_limb_t的个数,返回进位值。

_ntl_limb_t
_ntl_mpn_add (_ntl_limb_t *rp, const _ntl_limb_t *ap, long an, const  _ntl_limb_t *bp, long bn)
{
  _ntl_limb_t cy;

  cy = _ntl_mpn_add_n (rp, ap, bp, bn);
  if (an > bn)
    cy = _ntl_mpn_add_1 (rp + bn, ap + bn, an - bn, cy);
  return cy;
}

下面NTL_ZZ_FRADIX是2^{30},NTL_FRADIX_INV是其倒数 

#define NTL_RADIX           (1L<<NTL_NBITS)
#define NTL_FRADIX_INV  (((double) 1.0)/((double) NTL_RADIX))

#define NTL_ZZ_FRADIX ((double) (1L << NTL_NBITS))
#define NTL_ZZ_FRADIX_INV  (1.0/NTL_ZZ_FRADIX)

下面函数计算b*d+t,加到a上,进位值保存到t. 注意double转为long时是向下截取整数值。

static inline void
_ntl_addmulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t) 
{ 
   _ntl_limb_t t1 = b * d; 
   _ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(d)*NTL_ZZ_FRADIX_INV) ) - 1; 
   t2 = t2 + ( (t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS ); 
   t1 = CLIP(t1) + a + t;
   t = t2 + (t1 >> NTL_ZZ_NBITS); 
   a = CLIP(t1);
}

下面函数计算b*d + t

// (t, a) = b*d + t
static inline void
_ntl_mulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t) 
{ 
   _ntl_limb_t t1 =  b*d + t; 
   _ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(d)*NTL_ZZ_FRADIX_INV) ) - 1; 
   t = t2 + ((t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS); 
   a = CLIP(t1);
}

类似这样的辅助运算函数还有很多。特别注意下面函数中NTL_ZZ_RADIX=2^{30},所以运算中可以认为b*NTL_ZZ_RADIX = 0. NTL_ZZ_RADIX - d类似于取补码,实际这里考虑d的低30位的补码。

// (t, a) = b*(-d) + a + t, where t is "signed"
static inline void
_ntl_submulp(_ntl_limb_t& a, _ntl_limb_t b, _ntl_limb_t d, _ntl_limb_t& t) 
{ 
   _ntl_limb_t t1 =  b*(NTL_ZZ_RADIX-d) + a;
   _ntl_limb_t t2 = _ntl_signed_limb_t( DBL(b)*(DBL(NTL_ZZ_RADIX-d)*NTL_ZZ_FRADIX_INV) ) - 1; 
   _ntl_limb_t lo = CLIP(t1);
   _ntl_limb_t hi = t2 + ((t1 - (t2 << NTL_ZZ_NBITS)) >> NTL_ZZ_NBITS); 
   lo += t;
   a = CLIP(lo);
   t = hi - b - (lo >> (NTL_BITS_PER_LIMB_T-1));
}

注意最后t的赋值语句中(lo >> (NTL_BITS_PER_LIMB_T-1))总是0,可以删去。

下面函数_ntl_mpn_mul_1就是计算(up指向的大整数)*vl,保存到rp,返回进位值。


_ntl_limb_t
_ntl_mpn_mul_1 (_ntl_limb_t* rp, const _ntl_limb_t* up, long n, _ntl_limb_t vl) 
{
   _ntl_limb_t carry = 0;
   for (long i = 0; i < n; i++) 
      _ntl_mulp(rp[i], up[i], vl, carry);
   return carry;
}


_ntl_limb_t
_ntl_mpn_addmul_1 (_ntl_limb_t* rp, const _ntl_limb_t* up, long n, _ntl_limb_t vl)
{
   _ntl_limb_t carry = 0;
   for (long i = 0; i < n; i++) 
      _ntl_addmulp(rp[i], up[i], vl, carry);
   return carry;
}

_ntl_mpn_addmul_1计算up*vl,加到rp.

下面函数功能是移位减法,即计算

(carry, rp[n-1], ..., rp[0]) = (rp[n-1], ..., rp[0], shift_in)  - (up[n-1], ..., up[0]) * vl

条件编译语句提供2种实现方式,第一种容易理解,第二种相当于把for循环拆分。拆分循环后能加速的原因在于_ntl_swap函数中交换2个变量要用3个赋值语句(构造函数也视为赋值语句)。而下面这种拆分,_ntl_submulp之后只需2个赋值语句。

// compute (carry, rp[n-1], ..., rp[0]) = (rp[n-1], ..., rp[0], shift_in) 
//                                          - (up[n-1], ..., up[0]) * vl,
// and return carry (which may be negative, but stored as an unsigned).  No
// aliasing is allowed.  This is a special-purpose used by the tdiv_qr routine,
// to avoid allocating extra buffer space and extra shifting.  It is not a part
// of GMP interface.

_ntl_limb_t
_ntl_mpn_shift_submul_1(_ntl_limb_t* NTL_RESTRICT rp, _ntl_limb_t shift_in, const _ntl_limb_t* NTL_RESTRICT up, long n, _ntl_limb_t vl)
{
#if 0
   _ntl_limb_t carry = 0;
   for (long i = 0; i < n; i++) {
      _ntl_submulp(shift_in, up[i], vl, carry);
      _ntl_swap(shift_in, rp[i]);
   }

   return carry + shift_in;
#else
   // NOTE: loop unrolling seems to help a little bit
   _ntl_limb_t carry = 0;
   long i = 0;
   for (; i <= n-4; i += 4) {
      _ntl_submulp(shift_in, up[i], vl, carry);
      _ntl_limb_t tmp1 = rp[i];
      rp[i] = shift_in;

      _ntl_submulp(tmp1, up[i+1], vl, carry);
      _ntl_limb_t tmp2 = rp[i+1];
      rp[i+1] = tmp1;

      _ntl_submulp(tmp2, up[i+2], vl, carry);
      _ntl_limb_t tmp3 = rp[i+2];
      rp[i+2] = tmp2;

      _ntl_submulp(tmp3, up[i+3], vl, carry);
      shift_in = rp[i+3];
      rp[i+3] = tmp3;
   }

   for (; i < n; i++) {
      _ntl_submulp(shift_in, up[i], vl, carry);
      _ntl_swap(shift_in, rp[i]);
   }

   return carry + shift_in;
#endif
}

下面函数_ntl_addmulpsq计算a[j]*a[j] + t保存到t,进位值保存到carry. _ntl_addmulsq用b[0]去乘b的高位,结果加上carry加到a中。_ntl_mpn_base_sqr计算c = a^2,sa控制单元数。注意计算平方时,遇到i \not= j,a[i]*a[j] + a[j]*a[i]可以只算一项,然后乘2. 

void 
_ntl_addmulsq(long n, _ntl_limb_t *a, const _ntl_limb_t *b)
{
   _ntl_limb_t s = b[0];
   _ntl_limb_t carry = 0;
   for (long i = 0; i < n; i++) {
      _ntl_addmulp(a[i], b[i+1], s, carry);
   }
   a[n] += carry;
}

static inline void
_ntl_mpn_base_sqr(_ntl_limb_t *c, const _ntl_limb_t *a, long sa)
{
   long sc = 2*sa;

   for (long i = 0; i < sc; i++) c[i] = 0;

   _ntl_limb_t carry = 0;
   for (long i = 0, j = 0; j < sa; i += 2, j++) {
      _ntl_limb_t uc, t; 
       uc = carry + (c[i] << 1);
       t = CLIP(uc);
       _ntl_addmulpsq(t, a[j], carry);
       c[i] = t;
       _ntl_addmulsq(sa-j-1, c+i+1, a+j);
       uc =  (uc >> NTL_ZZ_NBITS) + (c[i+1] << 1);
       uc += carry;
       carry = uc >> NTL_ZZ_NBITS;
       c[i+1] = CLIP(uc);
   }
}

由此观之,_ntl_addmulsq完全是为了高精度乘法而设计的辅助函数。

下面函数计算(up表示的大整数)*(vp表示的大整数),结果保存到rp.

static inline _ntl_limb_t
_ntl_mpn_base_mul (_ntl_limb_t* rp, const _ntl_limb_t* up, long un, const _ntl_limb_t* vp, long vn)
{
  rp[un] = _ntl_mpn_mul_1 (rp, up, un, vp[0]);

  while (--vn >= 1)
    {
      rp += 1, vp += 1;
      rp[un] = _ntl_mpn_addmul_1 (rp, up, un, vp[0]);
    }
  return rp[un];
}

其它运算

限于篇幅,没有把所有类似的加减乘除运算列出来。

下面函数计算(b表示的整数的低hsa个_ntl_limb_t)+(b表示的整数右移hsa个_ntl_limb_t)并且保存到T. sb控制b的数据单元数。第一次读这段代码可能对命名有疑惑,事实上hsa是half sa,在以后的函数kar_mul中用到。

static 
long kar_fold(_ntl_limb_t *T, const _ntl_limb_t *b, long sb, long hsa)
{
   _ntl_limb_t carry = 0;

   for (long i = 0; i < sb-hsa; i++) {
      _ntl_limb_t t = b[i] + b[i+hsa] + carry;
      carry = t >> NTL_ZZ_NBITS;
      T[i] = CLIP(t);
   }

   for (long i = sb-hsa; i < hsa; i++) {
      _ntl_limb_t t = b[i] + carry;
      carry = t >> NTL_ZZ_NBITS;
      T[i] = CLIP(t);
   }

   if (carry) {
      T[hsa] = carry;
      return hsa+1;
   }
   else {
      return hsa;
   }

}

下面是简单的减法,计算T-c保存到T指向的地址,这个函数使用前必须确保T>c,否则无限循环。所有这个函数在实现上缺陷很大。

static
void kar_sub(_ntl_limb_t *T, const _ntl_limb_t *c, long sc)
{
   _ntl_limb_t carry = 0;

   for (long i = 0; i < sc; i++) {
      _ntl_limb_t t = T[i] - c[i] - carry;
      carry = (t >> NTL_ZZ_NBITS) & 1;
      T[i] = CLIP(t);
   }

   for (long i = sc; carry; i++) {
      _ntl_limb_t t = T[i] - 1;
      carry = (t >> NTL_ZZ_NBITS) & 1;
      T[i] = CLIP(t);
   }
}

下面这个函数先令c += hsa,然后计算(c表示的整数)+(T表示的整数),保存到c. 这个函数的缺陷是第二个for循环可能会访问越界。

static
void kar_add(_ntl_limb_t *c, const _ntl_limb_t *T, long sT, long hsa)
{
   c += hsa;
   _ntl_limb_t carry = 0;

   while (sT > 0 && T[sT-1] == 0) sT--;

   for (long i = 0; i < sT; i++) {
      _ntl_limb_t t = c[i] + T[i] + carry;
      carry = t >> NTL_NBITS;
      c[i] = CLIP(t);
   }

   for (long i = sT; carry; i++) {
      _ntl_limb_t t = c[i] + 1;
      carry = t >> NTL_NBITS;
      c[i] = CLIP(t);
   }
}

下面函数将c表示的整数的低hsa个_ntl_limb_t置为T的低hsa个_ntl_limb_t,在高位为(c表示的数的高位)+(T表示的数的高位).

static
void kar_fix(_ntl_limb_t *c, const _ntl_limb_t *T, long sT, long hsa)
{
   for (long i = 0; i < hsa; i++) {
      c[i] = T[i];
   }

   _ntl_limb_t carry = 0;

   for (long i = hsa; i < sT; i++) {
      _ntl_limb_t t = c[i] + T[i] + carry;
      carry = t >> NTL_NBITS;
      c[i] = CLIP(t);
   }

   for (long i = sT; carry; i++) {
      _ntl_limb_t t = c[i] + 1;
      carry = t >> NTL_NBITS;
      c[i] = CLIP(t);
   }
}

下面函数计算(a表示的整数)*(b表示的整数),结果保存到c. sa控制a的_ntl_limb_t位数,sb控制b的。stk是辅助栈,有sp个_ntl_limb_t的空间。不容易理解的是(hsa< sb)的情况,里面中了递归的做法,降低了时间复杂度,最后T3 = (a的高位)*(b的低位)+ (a的低位)*(b的高位)。关于时间复杂度的分析见代码后的内容。


#define KARX (16)

static
void kar_mul(_ntl_limb_t *c, const _ntl_limb_t *a, long sa, 
             const _ntl_limb_t *b, long sb, _ntl_limb_t *stk, long sp)
{

   if (sa < sb) {
      _ntl_swap(a, b);
      _ntl_swap(sa, sb);
   }

   if (sb < KARX) {
      /* classic algorithm */
 
      _ntl_mpn_base_mul(c, a, sa, b, sb);

   }
   else {
      long hsa = (sa + 1) >> 1;

      if (hsa < sb) {
         /* normal case */

         _ntl_limb_t *T1, *T2, *T3;

         /* allocate space */

         sp -= (hsa + 1) + ((hsa << 1) + 2);
         if (sp < 0) TerminalError("internal error: kmem overflow");

         T1 = c;
         T2 = stk;  stk += hsa + 1;  
         T3 = stk;  stk += (hsa << 1) + 2; 

         /* compute T1 = a_lo + a_hi */
         long sT1 = kar_fold(T1, a, sa, hsa);

         /* compute T2 = b_lo + b_hi */
         long sT2 = kar_fold(T2, b, sb, hsa);

         
         /* recursively compute T3 = T1 * T2 */
         kar_mul(T3, T1, sT1, T2, sT2, stk, sp);

         /* recursively compute a_hi * b_hi into high part of c */
         /* and subtract from T3 */
         kar_mul(c + (hsa << 1), a+hsa, sa-hsa, b+hsa, sb-hsa, stk, sp);
         kar_sub(T3, c + (hsa << 1), sa+sb-2*hsa);

         /* recursively compute a_lo*b_lo into low part of c */
         /* and subtract from T3 */
         kar_mul(c, a, hsa, b, hsa, stk, sp);
         kar_sub(T3, c, 2*hsa);

         /* finally, add T3 * NTL_RADIX^{hsa} to c */
         kar_add(c, T3, sT1+sT2, hsa);
      }
      else {
         /* degenerate case */

         _ntl_limb_t *T;
         
    
         sp -= (sb + hsa);
         if (sp < 0) TerminalError("internal error: kmem overflow");
 
         T = stk;  stk += sb + hsa;

         /* recursively compute b*a_hi into high part of c */
         kar_mul(c + hsa, a+hsa, sa-hsa, b, sb, stk, sp);

         /* recursively compute b*a_lo into T */
         kar_mul(T, a, hsa, b, sb, stk, sp);

         /* fix-up result */
         kar_fix(c, T, hsa+sb, hsa);
      }
   }
}

现在分析kar_mul的时间复杂度,设其时间复杂度为T(sa, sb)。由代码起始的swap技巧可以假设sa >= sb,进一步我们分析时间复杂度的上界,所以不妨设sa = sb,从而可以设T(sa) = T(sa,sa),即双参数数列化为单参数。 kar_mul最耗时间的情况是hsa < sb的情况,根据代码得到以下递推式

T(sa) = sa + T(\frac{sa}{2}) + T(\frac{sa}{2}) + sa + T(\frac{sa}{2}) + sa + sa + sa

 T(sa) =5 sa + 3 T(\frac{sa}{2})

再由算法分析的Master Theorem知T(sa) = O(sa^{log_2 3}). 再考虑一般的情况,就有

T(sa, sb) = O( (\max \{sa, sb\})^{log_2 3})

因此这里用技巧确实降低了乘法的时间复杂度。

浮点数

除了大整数,NTL还提供浮点数。_ntl_limb_t是32位数,有效数据有30位,可以解释成整数或浮点数。DBL函数返回解释为double的x.

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值