FFT什么的

  这里只有公式&做法,没有复杂的证明(其实是因为弱鸡yww不会)

  参考自国家集训队论文&各个博客

多项式

​  一个以 x x 为变量的多项式定义在一个代数域F上,将函数 A(x) A ( x ) 表示为形式和:

A(x)=j=0n1ajxj A ( x ) = ∑ j = 0 n − 1 a j x j

我们称 a0,a1,,an1 a 0 , a 1 , … , a n − 1 为多项式的系数,所有系数都属于数域 F F ,典型的情形是负数集合C

  如果一个多项式的最高次的非零系数是 ak a k ,则称 A(x) A ( x ) 的次数是 k k 。任何严格大于一个多项式次数的整数都是该多项式的次数界。因此,对于次数界为n的多项式 C(x) C ( x ) ,其次数可以是 0 0 ~n1之间的任何整数,包括 0 0 n1

​  我们在多项式上可以定义很多不同的运算。

多项式加法

​  如果 A(x) A ( x ) B(x) B ( x ) 是次数界为 n n 的多项式,那么他们的和也是一个次数界为n的多项式 C(x) C ( x ) 。对于所有属于定义域的 x x ,都有C(x)=A(x)+B(x)。也就是说,若

A(x)=j=0n1ajxjB(x)=j=0n1bjxj A ( x ) = ∑ j = 0 n − 1 a j x j B ( x ) = ∑ j = 0 n − 1 b j x j


C(x)=j=0n1cjxj C ( x ) = ∑ j = 0 n − 1 c j x j

其中
cj=aj+bj c j = a j + b j

​  例如,如果
A(x)=6x3+7x210x+9,B(x)=2x3+4x5 A ( x ) = 6 x 3 + 7 x 2 − 10 x + 9 , B ( x ) = − 2 x 3 + 4 x − 5


C(x)=4x3+7x26x+4 C ( x ) = 4 x 3 + 7 x 2 − 6 x + 4

多项式乘法

​  如果 A(x) A ( x ) 是次数界为 n n 的多项式,B(x)是次数界为 m m 的多项式,那么他们的乘积是一个次数界为n+m的多项式 C(x) C ( x ) 。其中

cj=k=0jakbjk c j = ∑ k = 0 j a k b j − k

​  例如,如果
A(x)=6x3+7x210x+9,B(x)=2x3+4x5 A ( x ) = 6 x 3 + 7 x 2 − 10 x + 9 , B ( x ) = − 2 x 3 + 4 x − 5

​  则
C(x)=12x614x5+44x420x375x2+86x45 C ( x ) = − 12 x 6 − 14 x 5 + 44 x 4 − 20 x 3 − 75 x 2 + 86 x − 45

多项式的表示

系数表达

​  对一个次数界为 n n 的多项式A(x)=j=0n1ajxj而言,其系数表达式一个由系数组成得到向量 a=(a0,a1,,an1) a = ( a 0 , a 1 , ⋯ , a n − 1 )

​  我们可以用秦久韶算法在 O(n) O ( n ) 的时间内求出多项式在给定点 x0 x 0 的值,即求值运算:

A(x0)=a0+x0(a1+a0(a2++x0(an1+x0(an1))) A ( x 0 ) = a 0 + x 0 ( a 1 + a 0 ( a 2 + ⋯ + x 0 ( a n − 1 + x 0 ( a n − 1 ) ⋯ ) )

​  类似的,对于两个分别用系数向量 a=(a0,a1,,an1),b=(b0,b1,,bn1) a = ( a 0 , a 1 , ⋯ , a n − 1 ) , b = ( b 0 , b 1 , ⋯ , b n − 1 ) 表示的多项式进行相加时,所需的时间是 O(n) O ( n ) 。我们只用输出系数向量 c=(c0,c1,,cn1) c = ( c 0 , c 1 , ⋯ , c n − 1 ) ,其中 ci=ai+bi c i = a i + b i

​  现在来考虑两个用系数形式表达的次数界为 n n 的多项式A(x),B(x)的乘法运算,所需要的时间是 O(n2) O ( n 2 ) 。系数向量 c c 也称为输入向量a,b的卷积。 c=ab c = a ⊗ b

点值表达

​  一个次数界为 n n 的多项式的点值表达就是一个有n个点值对所组成的集合。

{(x0,y0),(x1,y1),,(xn1,yn1)} { ( x 0 , y 0 ) , ( x 1 , y 1 ) , ⋯ , ( x n − 1 , y n − 1 ) }

使得对 k=0,1,,n1 k = 0 , 1 , ⋯ , n − 1 ,所有 xk x k 各不相同且 yk=A(xk) y k = A ( x k )

​  一个多项式可以有很多不同的点值表达,因为可以采用 n n 个不同的点构成的集合作为这种表示方法的基。

​  朴素的求值是O(n2)的。

​  求值的逆称为插值。当插值多项式的次数界等于已知的点值对的数目时,插值才是明确的。

​  我们可以在用高斯消元在 O(n3) O ( n 3 ) 内插值,也可以用拉格朗日插值 O(n2) O ( n 2 ) 内插值。

​  以上求值和插值可以将多项式的系数表达和点值表达进行相互转化,上面给出的算法的时间复杂度是 O(n2) O ( n 2 ) ,但我们可以巧妙地选取 xk x k 来加速这一过程,使其运行时间变为 O(nlogn) O ( n l o g n )

​  对于许多多项式相关的操作,点值表达式很便利的。

​  对于加法,如果 C(x)=A(x)+B(x) C ( x ) = A ( x ) + B ( x ) 。给定 A A 的点值表达

{(x0,y0),(x1,y1),,(xn1,yn1)}

B B 的点值表达
{(x0,y0),(x1,y1),,(xn1,yn1)}

(注意, A A B在相同的 n n 个位置求值),则C的点值表达是

{(x0,y0+y0),(x1,y1+y1),,(xn1,yn1+yn1)} { ( x 0 , y 0 + y 0 ′ ) , ( x 1 , y 1 + y 1 ′ ) , ⋯ , ( x n − 1 , y n − 1 + y n − 1 ′ ) }

因此,对两个点值形式表示的次数界为 n n 的多项式相加,时间复杂度是O(n)

​  类似的,如果 C(x)=A(x)B(x) C ( x ) = A ( x ) B ( x ) ,我们需要 2n 2 n 个点值对才能插出 C C 。给定A的点值表达

{(x0,y0),(x1,y1),,(x2n1,y2n1)} { ( x 0 , y 0 ) , ( x 1 , y 1 ) , ⋯ , ( x 2 n − 1 , y 2 n − 1 ) }

B B 的点值表达
{(x0,y0),(x1,y1),,(x2n1,y2n1)}

(注意, A A B在相同的 2n 2 n 个位置求值),则 C C 的点值表达是
{(x0,y0y0),(x1,y1y1),,(x2n1,y2n1y2n1)}

因此,对两个点值形式表示的次数界为 n n 的多项式相乘,时间复杂度是O(n)

​  最后,我们考虑一个采用点值表达的多项式,如何求其在某个新点上的值。最简单的方法是把该多项式转成系数形式表达,然后在新点处求值。

系数形式表示的多项式的快速乘法

​  如果我们选 n n 次单位复数根作为求值点,我们可以在O(nlogn)内求值和插值。我们先在对这两个多项式 A,B A , B 求值之前添加 n n 0,使其次数界加倍为 2n 2 n 。现在我们采用“ 2n 2 n 次单位复数根”作为求值点。

DFT&FFT&IDFT

单位复数根

​   n n 次单位复数根是满足wn=1的复数 w w n次单位复数根恰好有 n n 个,对于k=0,1,,n1,这些根是 e2πikn e 2 π i k n wn=e2πin w n = e 2 π i n 称为主 n n 次单位根,所有其他n次单位复数根都是 wn w n 的幂次。这 n n n次单位复数根在乘法意义下形成了一个群,即 wjnwkn=w(j+k)mod nn w n j w n k = w n ( j + k ) m o d   n ,而且这 n n n次单位复数根均匀分布在以复平面的原点为圆心的单位半径的圆周上。(图片from zjt)

这里写图片描述

​  消去引理:对任何整数 n0,k0,d>0 n ≥ 0 , k ≥ 0 , d > 0

wdkdn=wkn w d n d k = w n k

DFT

​  回顾一下,我们希望计算次数界为 n n 的多项式A(x) w0n,w1n,,wn1n w n 0 , w n 1 , ⋯ , w n n − 1 处的值(即在 n n n次单位复数根处)。对于 k=0,1,,n1 k = 0 , 1 , ⋯ , n − 1 ,定义结果 yk y k

yk=A(wkn)=j=0n1ajwkjn y k = A ( w n k ) = ∑ j = 0 n − 1 a j w n k j

向量 y=(y0,y1,,yn1) y = ( y 0 , y 1 , ⋯ , y n − 1 ) 就是系数向量 a a 的离散傅里叶变换(DFT),我们也记为y=DFTn(a)

FFT

​  利用单位复数根的特殊性质,我们可以在 O(nlogn) O ( n l o g n ) 内计算出 DFTn(a) D F T n ( a ) 。这里假设 n n 2的幂。

  FFT利用了分治策略。

  我们令 a=(a0,a1,,an1),a1=(a0,a2,,an2),a2=(a1,a3,,an1) a = ( a 0 , a 1 , ⋯ , a n − 1 ) , a 1 = ( a 0 , a 2 , ⋯ , a n − 2 ) , a 2 = ( a 1 , a 3 , ⋯ , a n − 1 )

  对于 k<n2 k < n 2 有:

yk=A(wkn)=j=0n1ajwkjn=j=0n21a2jw2kjn+j=0n21a2j+1w2kj+kn=j=0n21a2jw2kjn+wknj=0n21a2j+1w2kjn=j=0n21a1jwkjn2+wknj=0n21a2jwkjn2=y1k+wkny2k(1)(2)(3)(4)(5)(6) (1) y k = A ( w n k ) (2) = ∑ j = 0 n − 1 a j w n k j (3) = ∑ j = 0 n 2 − 1 a 2 j w n 2 k j + ∑ j = 0 n 2 − 1 a 2 j + 1 w n 2 k j + k (4) = ∑ j = 0 n 2 − 1 a 2 j w n 2 k j + w n k ∑ j = 0 n 2 − 1 a 2 j + 1 w n 2 k j (5) = ∑ j = 0 n 2 − 1 a 1 j w n 2 k j + w n k ∑ j = 0 n 2 − 1 a 2 j w n 2 k j (6) = y 1 k + w n k y 2 k

  对于 kn2 k ≥ n 2 有:
yk=A(wkn)=j=0n1ajwkjn=j=0n21a2jw2kjn+j=0n21a2j+1w2kj+kn=j=0n21a2jw2kjn+wknj=0n21a2j+1w2kjn=j=0n21a1jwkjn2+wknj=0n21a2jwkjn2=j=0n21a1jw(kn2)jn2+wknj=0n21a2jw(kn2)jn2=y1kn2+wkny2kn2=y1kn2wkn2ny2kn2(7)(8)(9)(10)(11)(12)(13)(14) (7) y k = A ( w n k ) (8) = ∑ j = 0 n − 1 a j w n k j (9) = ∑ j = 0 n 2 − 1 a 2 j w n 2 k j + ∑ j = 0 n 2 − 1 a 2 j + 1 w n 2 k j + k (10) = ∑ j = 0 n 2 − 1 a 2 j w n 2 k j + w n k ∑ j = 0 n 2 − 1 a 2 j + 1 w n 2 k j (11) = ∑ j = 0 n 2 − 1 a 1 j w n 2 k j + w n k ∑ j = 0 n 2 − 1 a 2 j w n 2 k j (12) = ∑ j = 0 n 2 − 1 a 1 j w n 2 ( k − n 2 ) j + w n k ∑ j = 0 n 2 − 1 a 2 j w n 2 ( k − n 2 ) j (13) = y 1 k − n 2 + w n k y 2 k − n 2 (14) = y 1 k − n 2 − w n k − n 2 y 2 k − n 2

  这样我们把 y1,y2 y 1 , y 2 合并为 y y 的时间复杂度是O(n)。所以总的时间复杂度是
T(n)=2T(n2)+O(n)=O(nlogn) T ( n ) = 2 T ( n 2 ) + O ( n ) = O ( n log ⁡ n )

IDFT

​  通过推导公式,我们得到:

ak=1nj=0n1yjwkjn a k = 1 n ∑ j = 0 n − 1 y j w n − k j

​  所以我们可以用类似FFT的方法在 O(nlogn) O ( n log ⁡ n ) 内求出 IDFTn(y) I D F T n ( y )

多项式乘法

​  我们可以在 O(n) O ( n ) 内补 0 0 O(nlogn)内求值, O(n) O ( n ) 内点值乘法, O(nlogn) O ( n log ⁡ n ) 内插值。所以我们可以在 O(nlogn) O ( n log ⁡ n ) 内求出 ab a ⊗ b

ab=IDFT2n(DFT2n(a)DFT2n(b)) a ⊗ b = I D F T 2 n ( D F T 2 n ( a ) ⋅ D F T 2 n ( b ) )

蝶形运算

  我们把由 y1k,y2k,wkn y 1 k , y 2 k , w n k 得到 yk,yk+n2 y k , y k + n 2 的过程称为蝴蝶操作。

​  我们发现,递归时 a a 是长这样的:

0   1   2   3   4   5   6   70   2   4   6 | 1   3   5   70   4 | 2   6 | 1   5 | 3   70 | 4 | 2 | 6 | 1 | 5 | 3 | 7

  总的蝶形运算是长这样的:
  
  这里写图片描述

​  可以发现,最后 ai a i 是原来的 arev(i) a r e v ( i ) 。所以我们可以交换 ai,arev(i) a i , a r e v ( i ) ,然后一层层来做。这样可以减小常数。

NTT

​  在某些时候,我们需要求模 p p 意义下的卷积。

​  先求出p的原根 g g ,可以发现,gp1n wn w n 的性质类似。所以我们可以用 gp1n g p − 1 n 来代替 wn w n

时间上的优化

​  令 tj=(aj+bj)+(ajbj)i,S=T×T t j = ( a j + b j ) + ( a j − b j ) i , S = T × T

​   sj s j 的实部为

k=0j(ak+bk)2(akbk)2=k=0j4akbk=4k=0jakbk(15) (15) ∑ k = 0 j ( a k + b k ) 2 − ( a k − b k ) 2 = ∑ k = 0 j 4 a k b k = 4 ∑ k = 0 j a k b k

  这样我们就可以求出 S=T×T S = T × T ,然后把 sj s j 除以 4 4

  这个方法可以把3次DFT改成 2 2 次DFT。

多项式求导

  给定A(x)=i0aixi,定义 A(x) A ( x ) 的形式导数为

A(x)=i1iaixi1 A ′ ( x ) = ∑ i ≥ 1 i a i x i − 1

多项式积分

  给定 A(x)=i0aixi A ( x ) = ∑ i ≥ 0 a i x i ,则

A(x)=i1ai1ixi ∫ A ( x ) = ∑ i ≥ 1 a i − 1 i x i

多项式求逆

​  多项式 A(x) A ( x ) 存在乘法逆元的充要条件是 A(x) A ( x ) 的常数项存在乘法逆元。

​  下面介绍一个 O(n log n) O ( n   l o g   n ) 计算乘法逆元的算法,它的本质是牛顿迭代法

​  首先求出 A(x) A ( x ) 常数项的逆元 b b ,令B(x)的初始值为 b b

​  假设已求出满足

A(x)B(x)1 (mod xn)

B(x) B ( x ) ,则

A(x)B(x)1(A(x)B(x)1)2A(x)(2B(x)B(x)2A(x))0 (mod xn)0 (mod x2n)1 (mod x2n)(16)(17)(18) (16) A ( x ) B ( x ) − 1 ≡ 0   ( m o d   x n ) (17) ( A ( x ) B ( x ) − 1 ) 2 ≡ 0   ( m o d   x 2 n ) (18) A ( x ) ( 2 B ( x ) − B ( x ) 2 A ( x ) ) ≡ 1   ( m o d   x 2 n )

​  我们可以用 O(n log n) O ( n   l o g   n ) 的时间计算出 2B(x)B(x)2A(x) 2 B ( x ) − B ( x ) 2 A ( x ) ,并将它赋值给 B(x) B ( x ) 进行下一次迭代。每迭代一次, B(x) B ( x ) 的有效项数 n n 都会增加一倍。于是该算法的时间复杂度为
T(n)=T(n/2)+O(nlogn)=O(nlogn)

多项式开根

  已知 A(x) A ( x ) ,求 B(x) B ( x ) 使得

B(x)2A(x) (mod xn) B ( x ) 2 ≡ A ( x )   ( m o d   x n )

  先求出 A(x) A ( x ) 常数项的平方根 b b (可以用二次剩余的东西来算,但我只会暴力算),令B(x)的初始值为 b b

  假设已求出满足

B(x)2A(x) (mod xn)

B(x) B ( x ) ,则

B(x)2A(x)(B(x)2A(x))2B(x)42B(x)2A(x)+A(x)2B(x)4+2B(x)2A(x)+A(x)2(B(x)2+A(x))2(B(x)2+A(x)2B(x))20 (mod xn)0 (mod x2n)0 (mod x2n)4B(x)2A(x) (mod x2n)(2B(x))2A(x) (mod x2n)A(x) (mod x2n)(19)(20)(21)(22)(23)(24) (19) B ( x ) 2 − A ( x ) ≡ 0   ( m o d   x n ) (20) ( B ( x ) 2 − A ( x ) ) 2 ≡ 0   ( m o d   x 2 n ) (21) B ( x ) 4 − 2 B ( x ) 2 A ( x ) + A ( x ) 2 ≡ 0   ( m o d   x 2 n ) (22) B ( x ) 4 + 2 B ( x ) 2 A ( x ) + A ( x ) 2 ≡ 4 B ( x ) 2 A ( x )   ( m o d   x 2 n ) (23) ( B ( x ) 2 + A ( x ) ) 2 ≡ ( 2 B ( x ) ) 2 A ( x )   ( m o d   x 2 n ) (24) ( B ( x ) 2 + A ( x ) 2 B ( x ) ) 2 ≡ A ( x )   ( m o d   x 2 n )

  我们可以在 O(nlogn) O ( n log ⁡ n ) 内算出 B(x)2+A(x)2B(x)=B(x)2+A(x)2B(x) B ( x ) 2 + A ( x ) 2 B ( x ) = B ( x ) 2 + A ( x ) 2 B ( x ) ,并把它赋值给 B(x) B ( x )

  时间复杂度: O(nlogn) O ( n log ⁡ n )

多项式ln

  给定形式幂级数 A(x)=i1aixi A ( x ) = ∑ i ≥ 1 a i x i ,定义

ln(1A(x))=i1A(x)ii ln ⁡ ( 1 − A ( x ) ) = − ∑ i ≥ 1 A ( x ) i i

  给定多项式 A(x)=1+i1aixi A ( x ) = 1 + ∑ i ≥ 1 a i x i ,令
B(x)=ln(A(x)) B ( x ) = ln ⁡ ( A ( x ) )


B(x)=A(x)A(x) B ′ ( x ) = A ′ ( x ) A ( x )

  只需要求出 A(x) A ( x ) 的乘法逆元,就可以求出 ln(A(x)) ln ⁡ ( A ( x ) )

多项式exp

  给定形式幂级数 A(x)=i1aixi A ( x ) = ∑ i ≥ 1 a i x i ,定义

exp(A(x))=i0A(x)ii! exp ⁡ ( A ( x ) ) = ∑ i ≥ 0 A ( x ) i i !

  令 f(x)=eA(x) f ( x ) = e A ( x ) ,可得到一个关于 f(x) f ( x ) 的方程
g(f(x))=ln(f(x))A(x)=0 g ( f ( x ) ) = ln ⁡ ( f ( x ) ) − A ( x ) = 0

  考虑用牛顿迭代解这一方程。首先 f(x) f ( x ) 的常数项是容易确定的(就是 1 1 )。

  设以求得f(x)的前 n n f0(x),即

f(x)f0(x)   (mod   xn) f ( x ) ≡ f 0 ( x )       ( m o d       x n )

  作泰勒展开得
0=g(f(x))=g(f0(x))+g(f0(x))(f(x)f0(x))     (mod   x2n)(25)(26) (25) 0 = g ( f ( x ) ) (26) = g ( f 0 ( x ) ) + g ′ ( f 0 ( x ) ) ( f ( x ) − f 0 ( x ) )           ( m o d       x 2 n )


f(x)f0(x)g(f0(x))g(f0(x))    (mod   x2n) f ( x ) ≡ f 0 ( x ) − g ( f 0 ( x ) ) g ′ ( f 0 ( x ) )         ( m o d       x 2 n )

  把上面那个式子带入得
f(x)=f0(x)ln(f0(x))A(x)1f0(x)=f0(x)(1ln(f0(x))+A(x))(27)(28) (27) f ( x ) = f 0 ( x ) − ln ⁡ ( f 0 ( x ) ) − A ( x ) 1 f 0 ( x ) (28) = f 0 ( x ) ( 1 − ln ⁡ ( f 0 ( x ) ) + A ( x ) )

  时间复杂度: O(nlogn) O ( n log ⁡ n )
  

多项式求幂

  给你 A(x),k A ( x ) , k ,求 Ak(x) A k ( x )

  设 A(x) A ( x ) 中最低次数项是 cxd c x d ,那么先把整个多项式除以 cxd c x d ,再求 ln ln ,把整个多项式乘以 k k ,再求exp,再乘上 ckxkd c k x k d

Ak(x)=exp(klnA(x)cxd))ckxkd A k ( x ) = exp ⁡ ( k ln ⁡ A ( x ) c x d ) ) c k x k d

  时间复杂度: O(nlogn) O ( n log ⁡ n )

多项式除法

​  给你 A(x),B(x) A ( x ) , B ( x ) ,求两个多项式 D(x),R(x) D ( x ) , R ( x ) 满足

A(x)=D(x)B(x)+R(x) A ( x ) = D ( x ) B ( x ) + R ( x )

​  若 A(x) A ( x ) 是一个 n n 阶多项式,则
AR(x)=xnA(1x)

  举个例子:比如说
A(x)=x3+2x2+3x+4AR(x)=1+2x+3x2+4x3 A ( x ) = x 3 + 2 x 2 + 3 x + 4 A R ( x ) = 1 + 2 x + 3 x 2 + 4 x 3

​  相当于把 A(x) A ( x ) 的系数反转。

  我们设 A(x) A ( x ) n n 阶多项式,B(x) m m 阶多项式,D(x) nm n − m 阶多项式, R(x) R ( x ) m1 m − 1 阶多项式。我们把上个式子的 x x 1x,然后全部乘上 xn x n

xnA(1x)=xnmD(1x)xmB(1x)+xnm+1xm1R(1x)AR(x)=DR(x)BR(x)+xnm+1RR(x) x n A ( 1 x ) = x n − m D ( 1 x ) x m B ( 1 x ) + x n − m + 1 x m − 1 R ( 1 x ) A R ( x ) = D R ( x ) B R ( x ) + x n − m + 1 R R ( x )

  然后我们把这个式子放在模 xnm+1 x n − m + 1 意义下,得到
AR(x)=DR(x)BR(x) (mod xnm+1)DR(x)=AR(x)(BR(x))1 (mod xnm+1) A R ( x ) = D R ( x ) B R ( x )   ( m o d   x n − m + 1 ) D R ( x ) = A R ( x ) ( B R ( x ) ) − 1   ( m o d   x n − m + 1 )

  因为 D(x) D ( x ) 的次数是 nm n − m ,所以不会受模意义的影响。

  然后把 D(x) D ( x ) 带入到原来的式子中,就可以算出 R(x) R ( x ) 了。

  时间复杂度: O(nlogn) O ( n log ⁡ n )

多点求值

  给你一个多项式 A(x) A ( x ) n n 个点x0,x1,,xn1,求这个多项式在这 n n 个点处的值,即求A(x0),A(x1),,A(xn1)

  考虑一个简单的做法:构造 Bi(x)=xxi,Ci(x)=A(x) mod Bi(x) B i ( x ) = x − x i , C i ( x ) = A ( x )   m o d   B i ( x ) ,那么 Bi(xi)=0 B i ( x i ) = 0 。所以 A(xi)=Ci(xi) A ( x i ) = C i ( x i ) 。但是计算 Bi(x) B i ( x ) Ci(x) C i ( x ) O(n) O ( n ) 的,必须加速这个过程。

  设当前求值的点为 X={x0,x1,,xn1} X = { x 0 , x 1 , ⋯ , x n − 1 } ,我们可以把这 n n 个点分为两半:

X0={x0,x1,,xn21}X1={xn2,xn2+1,,xn1}

  构造多项式

B0=i=0n21(xxi)B1=i=n2n1(xxi)A0=A mod B0A1=A mod B1 B 0 = ∏ i = 0 n 2 − 1 ( x − x i ) B 1 = ∏ i = n 2 n − 1 ( x − x i ) A 0 = A   m o d   B 0 A 1 = A   m o d   B 1

  那么当 xX0 x ∈ X 0 A(x)=A0(x) A ( x ) = A 0 ( x ) ,可以递归计算。当 xX1 x ∈ X 1 时同理。

  每一层计算 B0,B1,A0,A1 B 0 , B 1 , A 0 , A 1 的时间复杂度都是 O(nlogn) O ( n log ⁡ n )

  总的时间复杂度就是

T(n)=2T(n2)+O(nlogn)=O(nlog2n) T ( n ) = 2 T ( n 2 ) + O ( n log ⁡ n ) = O ( n log 2 ⁡ n )

快速插值

  考虑怎么求 gi=nj=0,ji(xixj) g i = ∏ j = 0 , j ≠ i n ( x i − x j ) ,也就是分母。

gi=j=0,jin(xixj)=limxxinj=0(xxj)xxi=(j=0n(xxj))|x=xi(29)(30)(31) (29) g i = ∏ j = 0 , j ≠ i n ( x i − x j ) (30) = lim x → x i ∏ j = 0 n ( x − x j ) x − x i (31) = ( ∏ j = 0 n ( x − x j ) ) ′ | x = x i

  可以分治求出 nj=0(xxj) ∏ j = 0 n ( x − x j ) 再求导后在所有 xi x i 处多点求值。

  分子直接分治求出。

  时间复杂度: O(nlog2n) O ( n log 2 ⁡ n )

小技巧1

  比如我们要计算两个实数序列的卷积 A×B=C A × B = C ,记 Di=(ai+bi)+(aibi)i D i = ( a i + b i ) + ( a i − b i ) i ,那么 Ci=14real(D2i) C i = 1 4 r e a l ( D 2 i )
  
  这样就可以把三次DFT减少到两次DFT。
  
  当然,如果 A=B A = B 那么这个优化是没有效果的。

任意模数FFT

模板

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
    if(a>b)
        swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    do
    {
        s=s*10+c-'0';
    }
    while((c=getchar())>='0'&&c<='9');
    return s;
}
int upmin(int &a,int b)
{
    if(b<a)
    {
        a=b;
        return 1;
    }
    return 0;
}
int upmax(int &a,int b)
{
    if(b>a)
    {
        a=b;
        return 1;
    }
    return 0;
}
const ll p=998244353;
const ll g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    while(b)
    {
        if(b&1)
            s=s*a%p;
        a=a*a%p;
        b>>=1;
    }
    return s;
}
const int maxn=600000;
ll inv[maxn];
namespace ntt
{
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i<n;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;i<n;i++)
            if(rev[i]<i)
                swap(a[i],a[rev[i]]);
        for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j<n;j+=i)
            {
                w=1;
                for(k=j;k<j+i/2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
                    a[k]=(u+v)%p;
                    a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i<n;i++)
                a[i]=a[i]*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i<n;i++)
            a[i]=0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=y[i]*(2-x[i]*y[i]%p)%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            if(a[0]==1)
                b[0]=1;
            else if(a[0]==0)
                b[0]=0;
            else
                //我也不会
                ;
            return;
        }
        sqrt(a,b,m>>1);
//      copy_clear(c,b,m>>1);
        int i;
        for(i=m;i<m<<1;i++)
            b[i]=0;
        inverse(b,d,m);
        init(m<<1);
        for(i=m;i<m<<1;i++)
            b[i]=d[i]=0;
        ll inv2=fp(2,p-2);
        copy_clear(x,a,m);
        ntt(x,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*d[i]%p;
        ntt(x,-1);
        for(i=0;i<m;i++)
            b[i]=((b[i]+x[i])%p*inv2)%p;
    }
    void derivative(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m-1;i++)
            b[i]=(i+1)*a[i+1]%p;
        b[m-1]=0;
    }
    void differential(ll *a,ll *b,int m)
    {
        int i;
        for(i=m-1;i>=1;i--)
            b[i]=a[i-1]*inv[i]%p;
        b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
        static ll c[maxn],d[maxn];
        derivative(a,c,m);
        inverse(a,d,m);
        init(m<<1);
        int i;
        for(i=m;i<n;i++)
            c[i]=d[i]=0;
        ntt(c,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            c[i]=c[i]*d[i]%p;
        ntt(c,-1);
        differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=1;
            return;
        }
        exp(a,b,m>>1);
        int i;
        for(i=m>>1;i<m;i++)
            b[i]=0;
        ln(b,y,m);
        init(m<<1);
        copy_clear(x,a,m);
        x[0]++;
        for(i=0;i<m;i++)
            x[i]=(x[i]-y[i])%p;
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
        int k=1;
        while(k<=n1-n2+1)
            k<<=1;
        int i;
        for(i=0;i<=n1;i++)
            d[i]=a[i];
        for(i=0;i<=n2;i++)
            e[i]=b[i];
        reverse(d,d+n1+1);
        reverse(e,e+n2+1);
        for(i=n1-n2+1;i<k<<1;i++)
            d[i]=e[i]=0;
        inverse(e,f,k);
        for(i=n1-n2+1;i<k<<1;i++)
            f[i]=0;
        init(k<<1);
        ntt::ntt(d,1);
        ntt::ntt(f,1);
        for(i=0;i<n;i++)
            e[i]=d[i]*f[i]%p;
        ntt::ntt(e,-1);
        for(i=0;i<=n1-n2;i++)
            c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }
};
ll b[maxn];
ll a[maxn];
ll c[maxn];
void get(ll *a,int n)
{
    int i;
    for(i=0;i<n;i++)
        a[i]=rand();
}
int main()
{
//  freopen("fft.txt","w",stdout);
//  srand(time(0));
//  int n=262144;
//  int bg,ed;
//  int i;
//  int times=100,j;
//  double s,s1;
//  inv[0]=inv[1]=1;
//  for(i=2;i<=n;i++)
//      inv[i]=-(p/i)*inv[p%i]%p;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::init(n);
//      ntt::ntt(a,1);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ntt :%.10lf\n",s/times);
//  s1=s;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      get(b,n);
//      bg=clock();
//      ntt::mul(a,b,c,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("mul :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::inverse(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("inv :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::sqrt(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("sqrt:%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::ln(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ln  :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::exp(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("exp :%.10lf %.10lf\n",s/times,s/s1);
//  return 0;
}

多点求值+快速插值

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll p=998244353;
const ll g=3;
const int maxw=262144;
const int maxn=270000;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int rt,cnt,ls[1000010],rs[1000010];
ll vx[100010],vy[100010],va[100010];
ll inv[maxn],w1[maxn],w2[maxn];
int rev[maxn];
void init()
{
    inv[0]=inv[1]=1;
    for(int i=2;i<=maxw;i++)
        inv[i]=-p/i*inv[p%i]%p;
    for(int i=2;i<=maxw;i<<=1)
    {
        w1[i]=fp(g,(p-1)/i);
        w2[i]=fp(w1[i],p-2);
    }
}
ll *f[1000010];
int len[maxn];
void clear(ll *a,int n)
{
    memset(a,0,(sizeof a[0])*n);
}
void ntt(ll *a,int n,int t)
{
    for(int i=1;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
        if(i>rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int i=2;i<=n;i<<=1)
    {
        ll wn=(t==1?w1[i]:w2[i]);
        for(int j=0;j<n;j+=i)
        {
            ll w=1;
            for(int k=j;k<j+i/2;k++)
            {
                ll u=a[k];
                ll v=a[k+i/2]*w%p;
                a[k]=(u+v)%p;
                a[k+i/2]=(u-v)%p;
                w=w*wn%p;
            }
        }
    }
    if(t==-1)
    {
        ll inv=fp(n,p-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%p;
    }
}
void mul(ll *a,ll *b,ll *c,int n,int m)
{
    int k=1;
    while(k<=n+m)
        k<<=1;
    static ll a1[maxn],a2[maxn];
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n+m;i++)
        c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
    if(n==1)
    {
        b[0]=fp(a[0],p-2);
        return;
    }
    getinv(a,b,n>>1);
    static ll a1[maxn],a2[maxn];
    clear(a1,n<<1);
    clear(a2,n<<1);
    for(int i=0;i<n;i++)
        a1[i]=a[i];
    for(int i=0;i<n>>1;i++)
        a2[i]=b[i];
    ntt(a1,n<<1,1);
    ntt(a2,n<<1,1);
    for(int i=0;i<n<<1;i++)
        a1[i]=a2[i]*(2-a2[i]*a1[i]%p)%p;
    ntt(a1,n<<1,-1);
    for(int i=0;i<n;i++)
        b[i]=a1[i];
}
void div(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn],a3[maxn];
    int k=1;
    while(k<=2*(n-m))
        k<<=1;
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    reverse(a1,a1+n+1);
    reverse(a2,a2+m+1);
    clear(a1+n-m+1,k-(n-m+1));
    clear(a2+n-m+1,k-(n-m+1));
    getinv(a2,a3,k);
    clear(a3+n-m+1,k-(n-m+1));
    ntt(a1,k,1);
    ntt(a3,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a3[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n-m;i++)
        c[i]=a1[i];
    reverse(c,c+n-m+1);
}
void getmod(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn];
    int k=1;
    while(k<=n)
        k<<=1;
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=m;i++)
        a1[i]=b[i];
    div(a,b,a2,n,m);
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<m;i++)
        c[i]=(a[i]-a1[i])%p;
}
void device(int l,int r,int &now)
{
    now=++cnt;
    len[now]=r-l+1;
    f[now]=new ll[len[now]+1];
    if(l==r)
    {
        f[now][1]=1;
        f[now][0]=-vx[l];
        return;
    }
    int mid=(l+r)>>1;
    device(l,mid,ls[now]);
    device(mid+1,r,rs[now]);
    mul(f[ls[now]],f[rs[now]],f[now],len[ls[now]],len[rs[now]]);
}
void getv(ll *a,int n,int l,int r,int now)
{
    ll *a1=new ll[len[now]];
    getmod(a,f[now],a1,n,len[now]);
    if(l==r)
    {
        va[l]=a1[0];
        return;
    }
    int mid=(l+r)>>1;
    getv(a1,len[now]-1,l,mid,ls[now]);
    getv(a1,len[now]-1,mid+1,r,rs[now]);
}
ll *s[1000010];
void getpoly(int l,int r,int now)
{
    s[now]=new ll[len[now]];
    if(l==r)
    {
        s[now][0]=va[l];
        return;
    }
    int mid=(l+r)>>1;
    getpoly(l,mid,ls[now]);
    getpoly(mid+1,r,rs[now]);
    int k=1;
    while(k<=len[now])
        k<<=1;
    static ll a1[maxn],a2[maxn],a3[maxn],a4[maxn];
    clear(a1,k);
    clear(a2,k);
    clear(a3,k);
    clear(a4,k);
    for(int i=0;i<len[ls[now]];i++)
        a1[i]=s[ls[now]][i];
    for(int i=0;i<=len[rs[now]];i++)
        a2[i]=f[rs[now]][i];
    for(int i=0;i<len[rs[now]];i++)
        a3[i]=s[rs[now]][i];
    for(int i=0;i<=len[ls[now]];i++)
        a4[i]=f[ls[now]][i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    ntt(a3,k,1);
    ntt(a4,k,1);
    for(int i=0;i<k;i++)
        a1[i]=(a1[i]*a2[i]+a3[i]*a4[i])%p;
    ntt(a1,k,-1);
    for(int i=0;i<len[now];i++)
        s[now][i]=a1[i];
}
int n;
ll a[maxn],b[maxn],c[maxn];
int main()
{
    init();
    scanf("%d",&n);
    for(int i=0;i<=n;i++)
        scanf("%lld%lld",&vx[i],&vy[i]);
    device(0,n,rt);
    for(int i=0;i<=n;i++)
        a[i]=f[rt][i+1]*(i+1)%p;
    getv(a,n,0,n,rt);
//  for(int i=0;i<=n;i++)
//      printf("%lld ",(va[i]+p)%p);
//  printf("\n");
    for(int i=0;i<=n;i++)
        va[i]=fp(va[i],p-2)*vy[i]%p;
    getpoly(0,n,rt);
    for(int i=0;i<=n;i++)
        printf("%lld ",(s[rt][i]+p)%p);
    printf("\n");
    return 0;
}
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值