NTT数论变换(三)

上一篇:https://blog.csdn.net/wxkhturfun/article/details/111186205
下一篇:https://blog.csdn.net/wxkhturfun/article/details/114596494

前言

前面两篇分析的都是“库里-图基”类型的蝶形运算,蝶形运算操作是先乘后加:

u = a[k] % P
t = w * (a[k + int(h / 2)] % P) % P
a[k] = (u + t) % P
a[k + int(h / 2)] = ((u - t) % P + P) % P

还存在另一种蝶形运算:“桑德-图基”,是先加后乘:

u = (a[k] + a[k + int(h / 2)])%P
t = w * ((a[k]-a[k + int(h / 2)])%P+P)%P
a[k] = u
a[k + int(h / 2)] = t

两种操作复杂度一致,但是考虑硬件开销,桑德-图基只用了一次乘法,而库里-图基用了两次,所以决定还将其实现一遍。

1.变化

桑德-图基相当于库里-图基的“逆”,也就是说整个系统的蝶形运算的输入改为输出,各参数的顺序不变,常量的值不变,箭头方向相反。
所以我们只需要改变两个地方:一个是wn的计算,还有就是Fntt()
首先给出原来的库里-图基:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
import random
def quick_mod(a: int, b: int,m: int):
     ans = 1
     a %= m
     while(b):
          if(b & 1):
               ans = ans * a % m
               b-=1
          b >>= 1
          a = a * a % m
     return ans
def GetWn(NUM:int ,wn:list):
    for i in range(NUM):
        t = 1 << i
        wn[i] = quick_mod(G, int((P - 1) / t), P)
def Prepare(A:list, B:list, a:list ,b:list, length=1 ):
     length_A = len(A)
     length_B = len(B)
     while((length <= 2 * length_A) or (length <= 2 * length_B)):
          length <<= 1
     A.extend( [0]*(length - length_A) ) 
     B.extend( [0]*(length - length_B) ) 
     a.extend([0]*length)
     b.extend([0]*length)
     for i in range(length_A):
          A[length - 1 - i] = A[length_A - 1 - i]
     for i in range(length - length_A):
        A[i] = 0
     for i in range(length_B):
          B[length - 1 - i] = B[length_B - 1 - i]
     for i in range(length-length_B):
        B[i] = 0
     for i in range(length):
          a[length - 1 - i] = A[i]
     for i in range(length):
          b[length - 1 - i] = B[i]
     return length
def  Rader(a:list, length:int):
     j = length >> 1
     for i in range(1,length-1):
          if(i < j): 
               a[i], a[j] = a[j], a[i] 
          k = length >> 1
          while(j >= k):
               j -= k
               k >>= 1
          if(j < k):
               j += k
     #return a
def Fntt(a:list ,length:int ,on:int ):

     #a=Rader(a, length)
     Rader(a, length)
     id = 0
     h = 2
     while (h <= length):
          print(h) 
          id += 1
          for j in range(0,length,h):
               w = 1
               for k in range(j,j+int(h/2)):
                    u = a[k] % P
                    t = w * (a[k + int(h / 2)] % P) % P
                    a[k] = (u + t) % P
                    a[k + int(h / 2)] = ((u - t) % P + P) % P
                    w = w * wn[id] % P 
          h <<= 1             
     if(on == -1):
          for i in range(1,int(length/2)):
               a[i], a[length - i] = a[length - i],a[i]
          Inv = quick_mod(length, P - 2, P)
          for i in range(length):
               a[i] = a[i] % P * Inv % P
 
def Conv(a:list ,b:list ,n:int ):
     Fntt(a, n, 1)
     Fntt(b, n, 1)
     for i in range(n):
          a[i] = a[i] * b[i] % P
     Fntt(a, n, -1)
     
def Transfer(a:list ,n:list ):
     t = 0
     for i in range(n):
          a[i] += t
          if(a[i] > 9):
               t = int(a[i] / 10)
               a[i] %= 10
          else:
               t = 0
def Print (a:list, n:int):
     flag = True
     for i in range(n-1,-1,-1):
          if(a[i] != 0  and  flag):
               print( a[i],end="")
               flag = False
          elif( not flag):
               print(a[i],end="")
     print("")#这里的print是用于换行的需要,因为前面都是end=""


N = 256
P = 7681
G = 17
NUM = 13
wn=[0]*NUM
GetWn(NUM,wn)
if __name__ == "__main__":
     a=[]
     b=[]
     for i in range (256):
          #a.append(i)
          #b.append(i)
          a.append(random.randint(0,256))
          b.append(random.randint(0,256))
     length=256
     with open("a.txt",'w')as f:
          for i in a:
               '''i=str(hex(i))
               i=i[2:]
               if len(i)<2:
                    i='0'*(2-len(i))+i
                    '''
               f.write(str(i))
               f.write("\n")
          f.close()
     with open("b.txt",'w')as f:
          for i in b:
               '''i=str(hex(i))
               i=i[2:]
               if len(i)<2:
                    i='0'*(2-len(i))+i'''
               f.write(str(i))
               f.write("\n")
          f.close()
     Conv(a,b,256)
     with open("mul.txt",'w')as f:
          for i in a:
               i=str(hex(i))
               i=i[2:]
               if len(i)<4:
                    i='0'*(4-len(i))+i
               f.write(i)
               f.write("\n")

下面给出桑德-图基:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
import random
def quick_mod(a: int, b: int,m: int):
    ans = 1
    a %= m
    while(b):
        if(b & 1):
            ans = ans * a % m
            b-=1
        b >>= 1
        a = a * a % m
    return ans
def GetWn(NUM:int ,wn:list):
    for i in range(NUM):
        t = 1 << (NUM-1-i)
        wn[i] = quick_mod(G, int((P - 1) / t), P)
def Prepare(A:list, B:list, a:list ,b:list, length=1 ):
    length_A = len(A)
    length_B = len(B)
    while((length <= 2 * length_A) or (length <= 2 * length_B)):
        length <<= 1
    A.extend( [0]*(length - length_A) ) 
    B.extend( [0]*(length - length_B) ) 
    a.extend([0]*length)
    b.extend([0]*length)
    for i in range(length_A):
        A[length - 1 - i] = A[length_A - 1 - i]
    for i in range(length - length_A):
        A[i] = 0
    for i in range(length_B):
        B[length - 1 - i] = B[length_B - 1 - i]
    for i in range(length-length_B):
        B[i] = 0
    for i in range(length):
        a[length - 1 - i] = A[i]
    for i in range(length):
        b[length - 1 - i] = B[i]
    return length
def  Rader(a:list, length:int):
    j = length >> 1
    for i in range(1,length-1):
        if(i < j): 
            a[i], a[j] = a[j], a[i] 
        k = length >> 1
        while(j >= k):
            j -= k
            k >>= 1
        if(j < k):
            j += k
    #return a
def Fntt(a:list ,length:int ,on:int ):
    id = 0
    h = length
    while(h>=2):
        '''for j in range(length-h,-h,-h):
            w = 1
            for k in range(j,j+int(h/2)):
                '''
        for j in range(0,length,h):
            w = 1
            for k in range(j,j+int(h/2)):
                u = (a[k] + a[k + int(h / 2)])%P
                t = w * ((a[k]-a[k + int(h / 2)])%P+P)%P
                a[k] = u
                a[k + int(h / 2)] = t
                w = w * wn[id] % P 
        id+=1
        h>>=1  
    Rader(a,length)          
    if(on == -1):
        for i in range(1,int(length/2)):
            a[i], a[length - i] = a[length - i],a[i]
        Inv = quick_mod(length, P - 2, P)
        for i in range(length):
            a[i] = a[i] % P * Inv % P

def Conv(a:list ,b:list ,n:int ):
    Fntt(a, n, 1)
    Fntt(b, n, 1)
    for i in range(n):
        a[i] = a[i] * b[i] % P
    Fntt(a, n, -1)
    
def Transfer(a:list ,n:list ):
    t = 0
    for i in range(n):
        a[i] += t
        if(a[i] > 9):
            t = int(a[i] / 10)
            a[i] %= 10
        else:
            t = 0
def Print (a:list, n:int):
    flag = True
    for i in range(n-1,-1,-1):
        if(a[i] != 0  and  flag):
            print( a[i],end="")
            flag = False
        elif( not flag):
            print(a[i],end="")
    print("")#这里的print是用于换行的需要,因为前面都是end=""


N = 256
P = 7681
G = 17
NUM = 9
wn=[0]*NUM
GetWn(NUM,wn)
if __name__ == "__main__":
    length=256
    a=[]
    b=[]
    with open("a.txt","r") as f:
        for line in f.readlines():
            line = line.strip('\n')  #去掉列表中每一个元素的换行符
            nums = line.split()
            for i in nums:
                a.append(int(i))
    with open("b.txt","r") as f:
        for line in f.readlines():
            line = line.strip('\n')  #去掉列表中每一个元素的换行符
            nums = line.split()
            for i in nums:
                b.append(int(i))
    Conv(a,b,256)
    with open("GS.txt",'w')as f:
        for i in a:
            i=str(hex(i))
            i=i[2:]
            if len(i)<4:
                    i='0'*(4-len(i))+i
            f.write(i)
            f.write("\n")

上面两个代码是相辅相成的,先执行第一个,再执行第二个,第二个输入与第一个保持一致,然后对比两者的输出:diff -u GS.txt mul.txt,如果终端没有任何打印,说明两者完全相同。
另外,有一点需要注意,代码给出的是256序列长度的数相乘,也就是8级蝶形运算,当不是256长时,库里-图基给出的代码不变,但是桑德-图基给出的代码中的wn的顺序要变!!! 两者的wn要在X级下保证完全颠倒(X是2的次幂,例子中X=8)
改变:wn的顺序颠倒、id从0开始计数。Fntt内部由先码位倒置(Rader)再计算,变为先计算,再进行码位倒置。其他还有些微不足道的改变,不予赘述。

关于两种蝶形运算的具体操作,可参见:《数字信号处理》清华大学出版社 程佩青 第四版

二.硬计算(不用蝶形运算)

def quick_mod(a: int, b: int,m: int):
     ans = 1
     a %= m
     while(b):
          if(b & 1):
               ans = ans * a % m
               b-=1
          b >>= 1
          a = a * a % m
     return ans
def originNTT(a:list):
     g=quick_mod(G, int((P - 1)/N), P)
     re=[0]*N
     for i in range(N):
          for j in range(N):
               fuck=quick_mod(g, int(i*j), P)
               re[i]+= a[j]*fuck
               re[i] = re[i]%7681
     return re
G=17
P=7681
N=256
if __name__=="__main__":
     a=[]
     with open("a.txt","r") as f:
          for line in f.readlines():
               line = line.strip('\n')  #去掉列表中每一个元素的换行符
               nums = line.split()
               for i in nums:
                    a.append(int(i))
     result=originNTT(a)
     with open("originNTT.txt",'w')as f:
        for i in result:
            i=str(hex(i))
            i=i[2:]
            if len(i)<4:
                    i='0'*(4-len(i))+i
            f.write(str(i))
            f.write("\n")
diff originNTT.txt GS.txt

结果相同

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Greate AUK

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值