NTT数论变换

下一篇:https://blog.csdn.net/wxkhturfun/article/details/111186205

1.参考链接

https://www.cnblogs.com/zarth/p/7288456.html
参考链接中给出的是c++代码,为了从头到尾理解一遍,在下用Python实现了一遍
需要注意的有:

  1. 整体代码与FFT(快速傅里叶变换)基本相仿
  2. 求逆的时候,即Fntt(a, n, -1)中的第三个参数为-1时,我们是乘上一个参数Inv,原理如下图:
    图片截自:https://zhuanlan.zhihu.com/p/80297169
    上图戴自:https://zhuanlan.zhihu.com/p/80297169
  3. 原作者用C字符串来处理输入,这里用Python列表来处理,差别在于最后一个字符’\0’,不过这没有影响,虽说一度纠结于此而影响改写
  4. 理论上python3的int型是无限长的(内存足够),但是还是得使用quick_mod函数,否则还是报错:溢出
  5. 如要想弄清所有的具体的蝶形运算原理、各个步骤的实现意义:https://zhuanlan.zhihu.com/p/80297169
    https://blog.csdn.net/GGN_2015/article/details/68922404
    6.阅读前,先阅读:链接…………从FTT到NTT

2.Python3 代码(魔改自C++)

给出与源码不同的另一组N P G NUM:

N = 256
P = 7681
G = 17
NUM = 9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
def quick_mod(a: int, b: int,m: int):
     ans = 1
     a %= m
     while(b):
          if(b & 1):
               ans = ans * a % m
               b-=1
          b >>= 1
          a = a * a % m
     return ans
def GetWn(NUM:int ,wn:list):
    for i in range(NUM):
        t = 1 << i
        wn[i] = quick_mod(G, int((P - 1) / t), P)
def Prepare(A:list, B:list, a:list ,b:list, length=1 ):
     length_A = len(A)
     length_B = len(B)
     while((length <= 2 * length_A) or (length <= 2 * length_B)):
          length <<= 1
     A.extend( [0]*(length - length_A) ) 
     B.extend( [0]*(length - length_B) ) 
     a.extend([0]*length)
     b.extend([0]*length)
     for i in range(length_A):
          A[length - 1 - i] = A[length_A - 1 - i]
     for i in range(length - length_A):
        A[i] = 0
     for i in range(length_B):
          B[length - 1 - i] = B[length_B - 1 - i]
     for i in range(length-length_B):
        B[i] = 0
     for i in range(length):
          a[length - 1 - i] = A[i]
     for i in range(length):
          b[length - 1 - i] = B[i]
     return length
def  Rader(a:list, length:int):
     j = length >> 1
     for i in range(1,length-1):
          if(i < j): 
               a[i], a[j] = a[j], a[i] 
          k = length >> 1
          while(j >= k):
               j -= k
               k >>= 1
          if(j < k):
               j += k
     #return a
def Fntt(a:list ,length:int ,on:int ):

     #a=Rader(a, length)
     Rader(a, length)
     id = 0
     h = 2
     while (h <= length): 
          id += 1
          for j in range(0,length,h):
               w = 1
               for k in range(j,j+int(h/2)):
                    u = a[k] % P
                    t = w * (a[k + int(h / 2)] % P) % P
                    a[k] = (u + t) % P
                    a[k + int(h / 2)] = ((u - t) % P + P) % P
                    w = w * wn[id] % P 
          h <<= 1             
     if(on == -1):
          for i in range(1,int(length/2)):
               a[i], a[length - i] = a[length - i],a[i]
          Inv = quick_mod(length, P - 2, P)
          for i in range(length):
               a[i] = a[i] % P * Inv % P
 
def Conv(a:list ,b:list ,n:int ):
     Fntt(a, n, 1)
     Fntt(b, n, 1)
     for i in range(n):
          a[i] = a[i] * b[i] % P
     Fntt(a, n, -1)
     
def Transfer(a:list ,n:list ):
     t = 0
     for i in range(n):
          a[i] += t
          if(a[i] > 9):
               t = int(a[i] / 10)
               a[i] %= 10
          else:
               t = 0
def Print (a:list, n:int):
     flag = True
     for i in range(n-1,-1,-1):
          if(a[i] != 0  and  flag):
               print( a[i],end="")
               flag = False
          elif( not flag):
               print(a[i],end="")
     print("")#这里的print是用于换行的需要,因为前面都是end=""

N = 1<<18
P = (479<<21) + 1
G = 3
NUM = 20
wn=[0]*NUM
GetWn(NUM,wn)
if __name__ == "__main__":
     nums = input("input:").split()
     A=[]
     B=[]
     for i in nums[0] :
          A.append(int(i))
     for i in nums[1] :
          B.append(int(i))
     a=[]
     b=[]
     length=Prepare(A, B,a,b,length=1)
     Conv(a,b,length)
     Transfer(a, length)
     Print(a,length)       

3.简单测试

  1. 生成测试文件
import random
with open("fntt.txt","w") as f:
    for i in range(10000):
        A=random.randint(1000,9999)
        B=random.randint(1000,9999)
        result=A*B
        f.write(str(A)+"\t"+str(B)+"\t"+str(result)+"\n")

w代表覆盖,将“w”改为“a”则代表追加

  1. 测试
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
def quick_mod(a: int, b: int,m: int):
     ans = 1
     a %= m
     while(b):
          if(b & 1):
               ans = ans * a % m
               b-=1
          b >>= 1
          a = a * a % m
     return ans
def GetWn(NUM:int ,wn:list):
    for i in range(NUM):
        t = 1 << i
        wn[i] = quick_mod(G, int((P - 1) / t), P)
def Prepare(A:list, B:list, a:list ,b:list, length=1 ):
     length_A = len(A)
     length_B = len(B)
     while((length <= 2 * length_A) or (length <= 2 * length_B)):
          length <<= 1
     A.extend( [0]*(length - length_A) ) 
     B.extend( [0]*(length - length_B) ) 
     a.extend([0]*length)
     b.extend([0]*length)
     for i in range(length_A):
          A[length - 1 - i] = A[length_A - 1 - i]
     for i in range(length - length_A):
        A[i] = 0
     for i in range(length_B):
          B[length - 1 - i] = B[length_B - 1 - i]
     for i in range(length-length_B):
        B[i] = 0
     for i in range(length):
          a[length - 1 - i] = A[i]
     for i in range(length):
          b[length - 1 - i] = B[i]
     return length
def  Rader(a:list, length:int):
     j = length >> 1
     for i in range(1,length-1):
          if(i < j): 
               a[i], a[j] = a[j], a[i] 
          k = length >> 1
          while(j >= k):
               j -= k
               k >>= 1
          if(j < k):
               j += k
     #return a
def Fntt(a:list ,length:int ,on:int ):

     #a=Rader(a, length)
     Rader(a, length)
     id = 0
     h = 2
     while (h <= length): 
          id += 1
          for j in range(0,length,h):
               w = 1
               for k in range(j,j+int(h/2)):
                    u = a[k] % P
                    t = w * (a[k + int(h / 2)] % P) % P
                    a[k] = (u + t) % P
                    a[k + int(h / 2)] = ((u - t) % P + P) % P
                    w = w * wn[id] % P 
          h <<= 1             
     if(on == -1):
          for i in range(1,int(length/2)):
               a[i], a[length - i] = a[length - i],a[i]
          Inv = quick_mod(length, P - 2, P)
          for i in range(length):
               a[i] = a[i] % P * Inv % P
 
def Conv(a:list ,b:list ,n:int ):
     Fntt(a, n, 1)
     Fntt(b, n, 1)
     for i in range(n):
          a[i] = a[i] * b[i] % P
     Fntt(a, n, -1)
     
def Transfer(a:list ,n:list ):
     t = 0
     for i in range(n):
          a[i] += t
          if(a[i] > 9):
               t = int(a[i] / 10)
               a[i] %= 10
          else:
               t = 0
def Print (a:list, n:int):
     flag = True
     result=""
     for i in range(n-1,-1,-1):
          if(a[i] != 0  and  flag):
               print( a[i],end="")
               result+=str(a[i])
               flag = False
          elif( not flag):
               print(a[i],end="")
               result+=str(a[i])
     #print("")
     return result

'''N = 1<<18
P = (479<<21) + 1
G = 3
NUM = 20'''
N = 1<<8
P = 257
G = 3
NUM = 20
wn=[0]*NUM
GetWn(NUM,wn)
if __name__ == "__main__":
     with open("fntt.txt","r") as f:
          sum=[0,0]
          for line in f.readlines():
               line = line.strip('\n')  #去掉列表中每一个元素的换行符
               nums = line.split()
               A=[]
               B=[]
               nttResult=""
               for i in nums[0] :
                    A.append(int(i))
               for i in nums[1] :
                    B.append(int(i))
               a=[]
               b=[]
               length=Prepare(A, B,a,b,length=1)
               Conv(a,b,length)
               Transfer(a, length)  
               print("A:"+nums[0]+"\t"+"B:"+nums[1]+"\t"+"result:"+nums[2]+"\t"+"FNTT:",end="") 
               nttResult=Print(a,length)
               if nttResult == nums[2]:
                    print("\tTrue")
                    sum[0]+=1
               else:
                    print("\tFalse")
                    sum[1]+=1
     accuracy=100*sum[0]/(sum[0]+sum[1])
     print("True:"+str(sum[0])+"\tFalse:"+str(sum[1])+"\t\t"+"Accuracy:",end="")
     print("%.2f%%" % accuracy)
                
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Greate AUK

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值