下一篇:https://blog.csdn.net/wxkhturfun/article/details/111186205
1.参考链接
https://www.cnblogs.com/zarth/p/7288456.html
参考链接中给出的是C++代码,为了从头到尾理解一遍,在下用Python重新实现了一遍。
需要注意的有:
- 整体代码与FFT(快速傅里叶变换)基本相仿
- 求逆的时候,即
Fntt(a, n, -1)
中的第三个参数为-1(即做逆变换)时,除了把结果的下标1..n-1反转之外,最后还要乘上一个参数Inv(长度n在模P下的逆元,由费马小定理求得:Inv = n^(P-2) mod P),原理如下图:
上图摘自:https://zhuanlan.zhihu.com/p/80297169
- 原作者用C字符串来处理输入,这里用Python列表来处理,差别在于C字符串末尾的结束符'\0'(Python列表没有这个字符)。不过这并不影响结果,虽说我一度纠结于此,影响了改写进度。
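- 虽然Python3的int理论上是任意精度的(只要内存足够),但仍然需要使用quick_mod函数(或内置的三参数pow(a, b, m))来做模幂,否则会报错:溢出。这个报错通常是因为直接计算G ** ((P - 1) / t)时,`/`得到的是浮点数,浮点幂运算超出范围就会抛出OverflowError。即使改成整数指数,中间结果也会大到无法在合理时间内算完。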
- 如果想弄清所有蝶形运算的具体原理以及各个步骤的实现意义,可参考:https://zhuanlan.zhihu.com/p/80297169
https://blog.csdn.net/GGN_2015/article/details/68922404
- 阅读本文前,建议先阅读:从FFT到NTT(链接…………)
2.Python3 代码(魔改自C++)
给出与源码不同的另一组N P G NUM:
N = 256
P = 7681
G = 17
NUM = 9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
def quick_mod(a: int, b: int, m: int):
    """Return ``a ** b`` reduced modulo ``m`` using square-and-multiply.

    The base is reduced modulo ``m`` first. Each step consumes the lowest bit
    of the exponent: when it is set, the running result absorbs the current
    power of the base, and the base is then squared.
    """
    result, base, exp = 1, a % m, b
    while exp:
        if exp & 1:
            result = result * base % m
        # Dropping the low bit (the old "b -= 1" is implied by the shift).
        exp >>= 1
        base = base * base % m
    return result
def GetWn(NUM: int, wn: list):
    """Fill ``wn[0:NUM]`` with ``G ** ((P - 1) // 2**i) mod P``.

    Fntt reads ``wn[level]`` as the twiddle factor for the butterfly span
    ``2**level``. Raising it to the power ``2**i`` gives ``G ** (P - 1)``,
    which is 1 modulo P when P is prime. Assuming G is a primitive root of P,
    wn[i] is therefore a primitive 2**i-th root of unity.

    The exponent uses integer floor division. The previous float division
    ``int((P - 1) / t)`` silently loses precision once P - 1 exceeds 2**53.
    NOTE: if 2**i does not divide P - 1 the exponent is truncated and wn[i]
    is not a root of unity, so such entries must not be used by Fntt.

    Reads the module-level constants G and P and the sibling quick_mod.
    """
    for i in range(NUM):
        wn[i] = quick_mod(G, (P - 1) // (1 << i), P)
def Prepare(A: list, B: list, a: list, b: list, length=1):
    """Zero-pad two digit lists and build little-endian copies for the NTT.

    A and B hold decimal digits, most significant first. Starting from
    ``length``, the transform size is doubled until it exceeds twice the
    longer operand. A and B are then left-padded with zeros in place to that
    size. a and b are extended by ``length`` zeros, and their first ``length``
    entries receive the padded digits in reverse order, so index 0 is the
    least significant digit (callers pass empty lists).

    Returns the transform size.
    """
    size_a, size_b = len(A), len(B)
    while length <= 2 * max(size_a, size_b):
        length <<= 1

    # Right-align the original digits inside a zero buffer of the final size.
    A[:] = [0] * (length - size_a) + A
    B[:] = [0] * (length - size_b) + B

    a.extend([0] * length)
    b.extend([0] * length)
    # Reverse so that position k carries the coefficient of 10**k.
    a[:length] = A[::-1]
    b[:length] = B[::-1]
    return length
def Rader(a: list, length: int):
    """Permute ``a[0:length]`` in place into bit-reversed index order.

    ``rev`` tracks the bit-reversal of the loop index and is advanced by
    "adding one from the top bit" (Rader's trick). Each pair is swapped only
    once, when idx < rev. ``length`` is expected to be a power of two, as
    produced by Prepare.
    """
    half = length >> 1
    rev = half
    for idx in range(1, length - 1):
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]
        bit = half
        # Clear the run of leading ones (carry propagation from the top)...
        while rev >= bit:
            rev -= bit
            bit >>= 1
        # ...then set the first zero bit; the loop guarantees rev < bit here.
        rev += bit
def Fntt(a: list, length: int, on: int):
    """In-place number-theoretic transform of ``a[0:length]`` modulo P.

    This is an iterative radix-2 transform. The input is first permuted into
    bit-reversed order (Rader), then one butterfly pass is run per span
    h = 2, 4, ..., length, using ``wn[level]`` (h == 2**level) as the root of
    unity.

    Args:
        a: coefficient list, modified in place.
        length: transform size; expected to be a power of two with
            log2(length) < len(wn).
        on: 1 for the forward transform, -1 for the inverse transform.
    """
    Rader(a, length)
    level = 0  # previously named `id`, which shadowed the builtin
    h = 2
    while h <= length:
        level += 1
        half = h // 2
        root = wn[level]  # loop-invariant for this pass
        for start in range(0, length, h):
            w = 1
            for k in range(start, start + half):
                u = a[k] % P
                t = w * (a[k + half] % P) % P
                a[k] = (u + t) % P
                # Python's % is already non-negative for a positive modulus.
                a[k + half] = (u - t) % P
                w = w * root % P
        h <<= 1
    if on == -1:
        # Inverse via the forward transform: X[n] = Y[(length - n) % length] / length,
        # so reverse indices 1..length-1 (the middle element maps to itself).
        for i in range(1, length // 2):
            a[i], a[length - i] = a[length - i], a[i]
        # length**-1 mod P via Fermat's little theorem (assumes P is prime).
        inv = quick_mod(length, P - 2, P)
        for i in range(length):
            a[i] = a[i] % P * inv % P
def Conv(a: list, b: list, n: int):
    """Cyclic convolution of a and b (size n) modulo P; the result replaces a.

    Both inputs are forward-transformed in place, so b is overwritten with its
    transform. They are then multiplied pointwise, and a is inverse-transformed.
    """
    for seq in (a, b):
        Fntt(seq, n, 1)
    # Pointwise product in the transform domain.
    for idx in range(n):
        a[idx] = (a[idx] * b[idx]) % P
    Fntt(a, n, -1)
def Transfer(a: list, n: int):
    """Propagate decimal carries through ``a[0:n]`` in place.

    ``a`` holds little-endian coefficients (a[0] is the units place) that can
    exceed 9 after convolution. Afterwards every a[i] for i < n is a single
    digit. The carry out of a[n-1] is discarded, so the caller must leave zero
    headroom at the top. Prepare sizes the buffer to more than twice the longer
    operand, which is at least the digit count of the product.

    Uses exact integer ``divmod``. The previous ``int(a[i] / 10)`` went through
    float division and produced wrong digits for very large coefficients.
    """
    carry = 0
    for i in range(n):
        a[i] += carry
        if a[i] > 9:
            carry, a[i] = divmod(a[i], 10)
        else:
            carry = 0
def Print(a: list, n: int):
    """Print the little-endian digit list ``a[0:n]`` as a decimal number.

    Digits are printed most significant first and leading zeros are skipped,
    followed by a newline. A zero value now prints "0"; previously every digit
    was treated as a leading zero and only an empty line was printed.
    """
    top = n - 1
    # Stop at index 0 so at least one digit is always printed.
    while top > 0 and a[top] == 0:
        top -= 1
    print("".join(str(a[i]) for i in range(top, -1, -1)))
# NTT parameters. P - 1 = 479 * 2**21 with 479 odd, so power-of-two
# transform sizes up to 2**21 have the required roots of unity modulo P.
N = 1 << 18  # not referenced by the code in this script
P = 479 * (1 << 21) + 1
G = 3  # generator passed to GetWn; assumed to be a primitive root of P
NUM = 20  # wn[0..NUM-1] precomputed, i.e. transform sizes up to 2**(NUM-1)
wn = [0] * NUM
GetWn(NUM, wn)
if __name__ == "__main__":
    # Read two non-negative decimal integers separated by whitespace,
    # multiply them with the NTT and print the product.
    operands = input("input:").split()
    A = [int(ch) for ch in operands[0]]
    B = [int(ch) for ch in operands[1]]
    a, b = [], []
    length = Prepare(A, B, a, b, length=1)
    Conv(a, b, length)
    Transfer(a, length)
    Print(a, length)
3.简单测试
- 生成测试文件
import random

# Write 10000 lines of "A<TAB>B<TAB>A*B" with random 4-digit operands;
# the exact product serves as the reference answer for the self-test.
with open("fntt.txt", "w") as out:
    for _ in range(10000):
        x = random.randint(1000, 9999)
        y = random.randint(1000, 9999)
        out.write(f"{x}\t{y}\t{x * y}\n")
其中"w"表示覆盖写入;将"w"改为"a"则表示追加写入。
- 测试
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:ys
def quick_mod(a: int, b: int, m: int):
    """Return ``a ** b`` reduced modulo ``m`` using square-and-multiply.

    The base is reduced modulo ``m`` first. Each step consumes the lowest bit
    of the exponent: when it is set, the running result absorbs the current
    power of the base, and the base is then squared.
    """
    result, base, exp = 1, a % m, b
    while exp:
        if exp & 1:
            result = result * base % m
        # Dropping the low bit (the old "b -= 1" is implied by the shift).
        exp >>= 1
        base = base * base % m
    return result
def GetWn(NUM: int, wn: list):
    """Fill ``wn[0:NUM]`` with ``G ** ((P - 1) // 2**i) mod P``.

    Fntt reads ``wn[level]`` as the twiddle factor for the butterfly span
    ``2**level``. Raising it to the power ``2**i`` gives ``G ** (P - 1)``,
    which is 1 modulo P when P is prime. Assuming G is a primitive root of P,
    wn[i] is therefore a primitive 2**i-th root of unity.

    The exponent uses integer floor division. The previous float division
    ``int((P - 1) / t)`` silently loses precision once P - 1 exceeds 2**53.
    NOTE: if 2**i does not divide P - 1 the exponent is truncated and wn[i]
    is not a root of unity, so such entries must not be used by Fntt.

    Reads the module-level constants G and P and the sibling quick_mod.
    """
    for i in range(NUM):
        wn[i] = quick_mod(G, (P - 1) // (1 << i), P)
def Prepare(A: list, B: list, a: list, b: list, length=1):
    """Zero-pad two digit lists and build little-endian copies for the NTT.

    A and B hold decimal digits, most significant first. Starting from
    ``length``, the transform size is doubled until it exceeds twice the
    longer operand. A and B are then left-padded with zeros in place to that
    size. a and b are extended by ``length`` zeros, and their first ``length``
    entries receive the padded digits in reverse order, so index 0 is the
    least significant digit (callers pass empty lists).

    Returns the transform size.
    """
    size_a, size_b = len(A), len(B)
    while length <= 2 * max(size_a, size_b):
        length <<= 1

    # Right-align the original digits inside a zero buffer of the final size.
    A[:] = [0] * (length - size_a) + A
    B[:] = [0] * (length - size_b) + B

    a.extend([0] * length)
    b.extend([0] * length)
    # Reverse so that position k carries the coefficient of 10**k.
    a[:length] = A[::-1]
    b[:length] = B[::-1]
    return length
def Rader(a: list, length: int):
    """Permute ``a[0:length]`` in place into bit-reversed index order.

    ``rev`` tracks the bit-reversal of the loop index and is advanced by
    "adding one from the top bit" (Rader's trick). Each pair is swapped only
    once, when idx < rev. ``length`` is expected to be a power of two, as
    produced by Prepare.
    """
    half = length >> 1
    rev = half
    for idx in range(1, length - 1):
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]
        bit = half
        # Clear the run of leading ones (carry propagation from the top)...
        while rev >= bit:
            rev -= bit
            bit >>= 1
        # ...then set the first zero bit; the loop guarantees rev < bit here.
        rev += bit
def Fntt(a: list, length: int, on: int):
    """In-place number-theoretic transform of ``a[0:length]`` modulo P.

    This is an iterative radix-2 transform. The input is first permuted into
    bit-reversed order (Rader), then one butterfly pass is run per span
    h = 2, 4, ..., length, using ``wn[level]`` (h == 2**level) as the root of
    unity.

    Args:
        a: coefficient list, modified in place.
        length: transform size; expected to be a power of two with
            log2(length) < len(wn).
        on: 1 for the forward transform, -1 for the inverse transform.
    """
    Rader(a, length)
    level = 0  # previously named `id`, which shadowed the builtin
    h = 2
    while h <= length:
        level += 1
        half = h // 2
        root = wn[level]  # loop-invariant for this pass
        for start in range(0, length, h):
            w = 1
            for k in range(start, start + half):
                u = a[k] % P
                t = w * (a[k + half] % P) % P
                a[k] = (u + t) % P
                # Python's % is already non-negative for a positive modulus.
                a[k + half] = (u - t) % P
                w = w * root % P
        h <<= 1
    if on == -1:
        # Inverse via the forward transform: X[n] = Y[(length - n) % length] / length,
        # so reverse indices 1..length-1 (the middle element maps to itself).
        for i in range(1, length // 2):
            a[i], a[length - i] = a[length - i], a[i]
        # length**-1 mod P via Fermat's little theorem (assumes P is prime).
        inv = quick_mod(length, P - 2, P)
        for i in range(length):
            a[i] = a[i] % P * inv % P
def Conv(a: list, b: list, n: int):
    """Cyclic convolution of a and b (size n) modulo P; the result replaces a.

    Both inputs are forward-transformed in place, so b is overwritten with its
    transform. They are then multiplied pointwise, and a is inverse-transformed.
    """
    for seq in (a, b):
        Fntt(seq, n, 1)
    # Pointwise product in the transform domain.
    for idx in range(n):
        a[idx] = (a[idx] * b[idx]) % P
    Fntt(a, n, -1)
def Transfer(a: list, n: int):
    """Propagate decimal carries through ``a[0:n]`` in place.

    ``a`` holds little-endian coefficients (a[0] is the units place) that can
    exceed 9 after convolution. Afterwards every a[i] for i < n is a single
    digit. The carry out of a[n-1] is discarded, so the caller must leave zero
    headroom at the top. Prepare sizes the buffer to more than twice the longer
    operand, which is at least the digit count of the product.

    Uses exact integer ``divmod``. The previous ``int(a[i] / 10)`` went through
    float division and produced wrong digits for very large coefficients.
    """
    carry = 0
    for i in range(n):
        a[i] += carry
        if a[i] > 9:
            carry, a[i] = divmod(a[i], 10)
        else:
            carry = 0
def Print(a: list, n: int):
    """Print ``a[0:n]`` as a decimal number (no trailing newline) and return it.

    ``a`` holds little-endian digits. They are emitted most significant first
    with leading zeros skipped. The same text is returned so the caller can
    compare it with a reference answer. A zero value now yields "0"; previously
    it printed nothing and returned "", which could never match a reference
    product of 0.
    """
    top = n - 1
    # Stop at index 0 so at least one digit is always produced.
    while top > 0 and a[top] == 0:
        top -= 1
    result = "".join(str(a[i]) for i in range(top, -1, -1))
    print(result, end="")
    return result
'''N = 1<<18
P = (479<<21) + 1
G = 3
NUM = 20'''
# Small NTT parameters used by this self-test; the string literal above keeps
# the full-size set for reference.
# NOTE(review): P = 257 is smaller than the largest possible convolution
# coefficient. For two 4-digit operands the middle coefficient can reach
# 4 * 9 * 9 = 324, which wraps modulo 257 and yields a wrong product (it is
# then counted as False). Also, P - 1 = 2**8, so only wn[0..8] are genuine
# roots of unity even though NUM = 20 entries are computed.
N = 1 << 8  # not referenced by the code in this script
P = 257
G = 3
NUM = 20
wn = [0] * NUM
GetWn(NUM, wn)
if __name__ == "__main__":
    # Self-test: each line of fntt.txt is "A<TAB>B<TAB>A*B". Multiply A and B
    # with the NTT, compare against the reference product and report accuracy.
    # The counters were previously kept in a list named `sum`, shadowing the builtin.
    passed = 0
    failed = 0
    with open("fntt.txt", "r") as f:
        for line in f:
            nums = line.split()
            if not nums:
                continue  # tolerate blank lines (e.g. a trailing newline)
            A = [int(ch) for ch in nums[0]]
            B = [int(ch) for ch in nums[1]]
            a, b = [], []
            length = Prepare(A, B, a, b, length=1)
            Conv(a, b, length)
            Transfer(a, length)
            print(f"A:{nums[0]}\tB:{nums[1]}\tresult:{nums[2]}\tFNTT:", end="")
            nttResult = Print(a, length)
            if nttResult == nums[2]:
                print("\tTrue")
                passed += 1
            else:
                print("\tFalse")
                failed += 1
    total = passed + failed
    # Guard against an empty input file (previously ZeroDivisionError).
    accuracy = 100 * passed / total if total else 0.0
    print(f"True:{passed}\tFalse:{failed}\t\tAccuracy:", end="")
    print("%.2f%%" % accuracy)