多项式合集

多项式合集

拉格朗日插值

问题背景

给出 n n n 个点 ( x i , y i ) (x_i,y_i) (xi,yi),令这 n n n 个点确定的多项式为 L ( x ) L(x) L(x),求 L ( k )   m o d   998244353 L(k)\bmod 998244353 L(k)mod998244353 的值。

结论

L ( x ) = ∑ i = 1 n y i l i ( x ) L(x) = \sum_{i=1}^n y_il_i(x) L(x)=i=1nyili(x)

其中每个 l i ( x ) l_i(x) li(x) 为拉格朗日基本多项式,表达式为

l i ( x ) = ∏ j = 1 , j ≠ i n x − x j x i − x j l_i(x) = \prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j} li(x)=j=1,j=inxixjxxj

其特点是 l i ( x i ) = 1 l_i(x_i)=1 li(xi)=1 ∀ j ≠ i \forall j\ne i j=i l i ( x j ) = 0 l_i(x_j)=0 li(xj)=0

推导

抛开拉插,这道题明显可以列方程组然后使用高斯消元求解,但是复杂度为 O ( n 3 ) O(n^3) O(n3) 且精度问题明显,所以拉格朗日是这样考虑的:

对于每个点 P i ( x i , y i ) P_i(x_i,y_i) Pi(xi,yi),构造一个 n − 1 n-1 n1 次多项式 l i ( x ) l_i(x) li(x) 使其在 x i x_i xi 上取值为 1 1 1,在其余 x j x_j xj 上为 0 0 0。构造的结果就是上面的结论:

l i ( x ) = ∏ j = 1 , j ≠ i n x − x j x i − x j l_i(x) = \prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j} li(x)=j=1,j=inxixjxxj

这个多项式的正确性还是很显然的。然后我们也知道这个多项式它就是唯一的。

然后考虑构造答案:很显然对于点 P i ( x i , y i ) P_i(x_i,y_i) Pi(xi,yi),只有 l i ( x i ) l_i(x_i) li(xi) 的取值为 1 1 1,其他的都为 0 0 0。所以答案的正确性也是比较显然的:对于 x i x_i xi,只有 y i l i ( x i ) y_il_i(x_i) yili(xi) 产生了贡献,其余的都是 0 0 0。故这个多项式是正确的。

所以回到一开始,我们需要的就是

f ( k ) = ∑ i = 1 n y i ∏ j = 1 , j ≠ i n k − x j x i − x j f(k) = \sum_{i=1}^n y_i\prod_{j=1,j\ne i}^n\frac{k-x_j}{x_i-x_j} f(k)=i=1nyij=1,j=inxixjkxj

由于模数是质数,所以使用费马小定理求逆元,跑得飞快。

复杂度 O ( n 2 ) O(n^2) O(n2),求逆元就是个很小的常数

#include <cstdio>
#include <cctype>
#define il inline

typedef long long ll;

inline ll read()
{
    char c = getchar();
    ll s = 0;
    bool x = 0;
    while (!isdigit(c))
        x = x | (c == '-'), c = getchar();
    while (isdigit(c))
        s = 10 * s + c - '0', c = getchar();
    return x ? -s : s;
}

const ll maxn = 2e3 + 5, mod = 998244353;
ll x[maxn], y[maxn];

ll pow(ll base, ll p)
{
    ll ans = 1;
    base = (base + mod) % mod;
    for (; p; p >>= 1)
    {
        if (p & 1)
            ans = ans * base % mod;
        base = base * base % mod;
    }
    return ans;
}

il ll inv(ll n)
{
    return pow(n, mod - 2);
}

int main()
{
    ll n = read(), k = read();
    for (int i = 1 ; i <= n; ++i)
        x[i] = read(), y[i] = read();
    ll ans = 0;
    for (int i = 1; i <= n; ++i)
    {
        ll prod1 = 1, prod2 = 1;
        for (int j = 1; j <= n; ++j)
        {
            if (i == j)
                continue;
            prod1 = prod1 * (k - x[j]) % mod;
            prod2 = prod2 * (x[i] - x[j]) % mod;
        }
        ans = (ans + prod1 * y[i] % mod * inv(prod2) % mod + mod) % mod;
    }
    printf("%lld\n", ans);
    return 0;
}

拉格朗日插值与范德蒙矩阵

可以考虑将这 n + 1 n+1 n+1 个点值表示为如下形式:

[ x 0 0 x 0 1 x 0 2 ⋯ x 0 n x 1 0 x 1 1 x 1 2 ⋯ x 1 n ⋮ ⋮ ⋮ ⋮ x n 0 x n 1 x n 2 ⋯ x n n ] [ a 0 a 1 ⋮ a n ] = [ y 0 y 1 ⋮ y n ] \begin{bmatrix} x_0^0 & x_0^1 & x_0^2 &\cdots &x_0^n\\ x_1^0 & x_1^1 & x_1^2 &\cdots & x_1^n\\ \vdots & \vdots & \vdots & &\vdots\\ x_n^0 & x_n^1 & x_n^2 & \cdots & x_n^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix} x00x10xn0x01x11xn1x02x12xn2x0nx1nxnna0a1an=y0y1yn

左边这个矩阵就是所谓的范德蒙德矩阵,记作 V \boldsymbol V V,系数列向量记作 A \boldsymbol A A,右边的记作 B \boldsymbol B B,则很明显:

V A = B \boldsymbol{VA} = \boldsymbol B VA=B

打开来看清楚些实际就是多项式 f f f 在每个点处的值:

y j = f ( x j ) = ∑ i = 0 n a i x j i y_j = f(x_j) = \sum_{i = 0}^na_ix_j^i yj=f(xj)=i=0naixji

我们把两边都乘上 V − 1 \boldsymbol V^{-1} V1

[ a 0 a 1 ⋮ a n ] = V − 1 [ y 0 y 1 ⋮ y n ] \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\boldsymbol V^{-1} \begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix} a0a1an=V1y0y1yn

就得到了 a i a_i ai 一定可以表示为某种形如

a k = ∑ [ ⋮ ] y k a_k = \sum \begin{bmatrix} \vdots \end{bmatrix}y_k ak=[]yk

的形式,即 a k a_k ak 只与 x i x_i xi y k y_k yk 有关。

所以不难发现对于一个要求的 f ( x ⊖ ) f(x_\ominus) f(x),都可以被表示为如下形式

f ( x ⊖ ) = ∑ δ k ( x ⊖ ) y k f(x_\ominus)=\sum\delta_k(x_\ominus)y_k f(x)=δk(x)yk

δ k ( x ) \delta_k(x) δk(x) 构造的过程即需要考虑 x = x k x=x_k x=xk δ j ( x ) = 0 ∧ δ k ( x ) = 1 \delta_j(x) = 0\land\delta_k(x) = 1 δj(x)=0δk(x)=1,其中 k ≠ j k\not=j k=j δ j ( x k ) = 0 \delta_j(x_k) = 0 δj(xk)=0 说明每一个 δ j \delta_j δj 都要有 ( x − x k ) (x-x_k) (xxk) 这个因式,然后又因为 δ k ( x k ) = 1 \delta_k(x_k) = 1 δk(xk)=1,所以最终构造出来就是上面的结果:

f ( x ) = ∑ i = 1 n y i ∏ j = 1 , j ≠ i n x − x j x i − x j f(x) = \sum_{i=1}^n y_i\prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j} f(x)=i=1nyij=1,j=inxixjxxj

我们其实也可以利用拉格朗日插值来求范德蒙矩阵的逆阵,复杂度 O ( n 2 ) O(n^2) O(n2)

开始全家桶之前

形式化定义

约定: f i f_i fi 表示 f ( x ) f(x) f(x) x i x^i xi 处的系数,即一个多项式可以表示为 ∑ i = 0 f i x i \displaystyle\sum_{i = 0} f_ix^i i=0fixi 的形式。

两个多项式的加减法定义为

f ( x ) ± g ( x ) = ∑ i = 0 ( f i ± g i ) x i f(x) \pm g(x) = \sum_{i = 0}(f_i \pm g_i)x^i f(x)±g(x)=i=0(fi±gi)xi

复杂度 O ( n ) O(n) O(n)

两个多项式的乘法(加法卷积)定义为:

f ( x ) ∗ g ( x ) = ∑ i = 0 x i ∑ j = 0 f j g i − j f(x)*g(x) = \sum_{i = 0}x^i\sum_{j = 0}f_jg_{i - j} f(x)g(x)=i=0xij=0fjgij

不难发现其正确性。可以手动模拟一下多项式的乘法看看是不是这样子的。其本质也就是卷完之后合并同类项。朴素的做的话复杂度是 O ( n 2 ) O(n^2) O(n2) 的,下面要讲的 FFT/NTT 可以加速到 O ( n log ⁡ n ) O(n\log n) O(nlogn)

有些时候,题目只对多项式的前若干项感兴趣,所以我们给运算设定一个上界,即 ( m o d x n ) \pmod{x^n} (modxn)。意思就是只考虑这个多项式的前 n n n,从 x n x^n xn 开始以后的全部舍掉。

不难发现由加法和乘法是从低位到高位贡献的,所以

( f ( x )   m o d   x n ± g ( x )   m o d   x n )   m o d   x n = ( f ( x ) ± g ( x ) )   m o d   x n ( f ( x )   m o d   x n ) ∗ ( g ( x )   m o d   x n )   m o d   x n = ( f ( x ) ∗ g ( x ) )   m o d   x n \begin{aligned} (f(x) \bmod{x^n} \pm g(x)\bmod{x^n})\bmod{x^n} &= (f(x) \pm g(x))\bmod{x^n}\\ (f(x) \bmod{x^n}) * (g(x)\bmod{x^n})\bmod{x^n} &= (f(x) * g(x))\bmod{x^n}\\ \end{aligned} (f(x)modxn±g(x)modxn)modxn(f(x)modxn)(g(x)modxn)modxn=(f(x)±g(x))modxn=(f(x)g(x))modxn

下面我们就开始学习多项式的各种操作吧

快速傅里叶变换(FFT)

FFT 可以加速卷积,让时间复杂度从 O ( n 2 ) O(n^2) O(n2) 降到 O ( n log ⁡ n ) O(n\log n) O(nlogn),学习 FFT 的基础操作前,需要先了解复数,因为 FFT 就是基于单位复数根的良好性质实现的。

复数基础

(数学选修 2-2 内容)

定义虚数单位 i 2 = − 1 \mathrm i^2 = \sqrt{-1} i2=1 ,把形如 a + b i   ( a , b ∈ R ) a + b\mathrm i\:(a,b\in\mathbb R) a+bi(a,bR) 的数称为复数,所有复数的集合称为复数集 C \mathbb C C

复数一般使用 z z z 表示,表示为 z = a + b i z = a + b\mathrm i z=a+bi,这种形式称为复数的代数形式。 a a a 被称为复数的实部, b b b 称为复数的虚部,未加说明的情况下一般认为 a , b ∈ R a,b\in\mathbb R a,bR。很明显地,当 a = 0 ∧ b ≠ 0 a = 0\land b\not=0 a=0b=0 时,这个复数为纯虚数,当 b = 0 b=0 b=0 时,这个复数为实数。

每个复数 a + b i a + b\mathrm i a+bi 都能对应平面直角坐标系里面的一个点 ( a , b ) (a,b) (a,b),同样的也可以对应一个向量 ( a , b ) (a,b) (a,b)。故定义复数的模为 a 2 + b 2 \sqrt{a^2 + b^2} a2+b2

定义复数的加法与乘法:
( a + b i ) + ( c + d i ) = ( a + c ) + ( b + d ) i \begin{aligned} &(a + b\mathrm i) + (c + d\mathrm i)\\ =&(a + c) + (b + d)\mathrm i \end{aligned} =(a+bi)+(c+di)(a+c)+(b+d)i

( a + b i ) ( c + d i ) = a c + a d i + c b i + b d i 2 = ( a c − b d ) + ( a d + b c ) i \begin{aligned} &(a+b\mathrm i)(c + d\mathrm i)\\ =&ac + ad\mathrm i + cb\mathrm i + bd\mathrm i^2\\ =&(ac - bd) + (ad + bc)\mathrm i \end{aligned} ==(a+bi)(c+di)ac+adi+cbi+bdi2(acbd)+(ad+bc)i

这都是比较显然的。

容易看出复数满足很多实数的运算律。

定义复数 z = a + b i z=a+b\mathrm i z=a+bi 的共轭复数为 z ‾ = a − b i \overline{z} = a - b\mathrm i z=abi,不难发现 z z z z ‾ \overline{z} z 关于实轴对称。
z z ‾ = ( a + b i ) ( a − b i ) = a 2 + b 2 = ∣ z ∣ 2 z\overline z=(a+b\mathrm i)(a-b\mathrm i) = a^2 + b^2=|z|^2 zz=(a+bi)(abi)=a2+b2=z2
复数既然可以对应平面直角坐标系中的向量,不难发现其可以使用其模长与辐角来表示:
z = a + b i    ⟺    z = r ( cos ⁡ θ + i sin ⁡ θ ) z=a+b\mathrm i\iff z = r(\cos\theta+\mathrm i\sin\theta) z=a+biz=r(cosθ+isinθ)
其中 r r r z z z 的模长, θ \theta θ 为其辐角。即我们可以把一个复数表示成二元组 ( r , θ ) (r,\theta) (r,θ) 的形式。

现在考虑两个复数 ( r 1 , θ 1 ) (r_1,\theta_1) (r1,θ1) ( r 2 , θ 2 ) (r_2,\theta_2) (r2,θ2) 相乘得到的结果:
( r 1 , θ 1 ) × ( r 2 , θ 2 ) = r 1 ( cos ⁡ θ 1 + i sin ⁡ θ 1 ) r 2 ( cos ⁡ θ 2 + i sin ⁡ θ 2 ) = r 1 r 2 ( cos ⁡ θ 1 cos ⁡ θ 2 − sin ⁡ θ 1 sin ⁡ θ 2 + i sin ⁡ θ 1 cos ⁡ θ 2 + i sin ⁡ θ 2 cos ⁡ θ 1 ) = r 1 r 2 ( cos ⁡ ( θ 1 + θ 2 ) + i sin ⁡ ( θ 1 + θ 2 ) ) = ( r 1 r 2 , θ 1 + θ 2 ) \begin{aligned} (r_1,\theta_1)\times(r_2,\theta_2) &= r_1(\cos\theta_1 + \mathrm i\sin\theta_1)r_2(\cos\theta_2 + \mathrm i\sin\theta_2)\\ &=r_1r_2(\cos\theta_1\cos\theta_2 - \sin\theta_1\sin\theta_2 + \mathrm i\sin\theta_1\cos\theta_2 + \mathrm i\sin\theta_2\cos\theta_1)\\ &=r_1r_2\left(\cos(\theta_1 + \theta_2) + \mathrm i\sin(\theta_1 + \theta_2)\right)\\ &=(r_1r_2,\theta_1 + \theta_2) \end{aligned} (r1,θ1)×(r2,θ2)=r1(cosθ1+isinθ1)r2(cosθ2+isinθ2)=r1r2(cosθ1cosθ2sinθ1sinθ2+isinθ1cosθ2+isinθ2cosθ1)=r1r2(cos(θ1+θ2)+isin(θ1+θ2))=(r1r2,θ1+θ2)
于是我们可以概括复数乘法的法则:模长相乘,辐角相加。(上述推导需要掌握基本的三角恒等变换)

从欧拉公式到单位圆

给出复数指数幂的定义:
e x + y i = e x ( cos ⁡ y + i sin ⁡ y ) \mathrm e^{x +y\mathrm i} = e^x(\cos y + \mathrm i\sin y) ex+yi=ex(cosy+isiny)
这个公式是由我也不会证明的泰勒展开推导出来的:
sin ⁡ ( x ) = x − x 3 3 ! + x 5 5 ! − x 7 7 ! + x 9 9 ! + ⋯ = ∑ k = 1 ∞ ( − 1 ) k − 1 x 2 k − 1 ( 2 k − 1 ) ! cos ⁡ ( x ) = 1 − x 2 2 ! + x 4 4 ! − x 6 6 ! + x 8 8 ! + ⋯ = ∑ k = 0 ∞ ( − 1 ) k x 2 k ( 2 k ) ! e x = 1 + x + x 2 2 ! + x 3 3 ! + x 4 4 ! + ⋯ = ∑ k = 0 ∞ x k k ! \begin{aligned} \sin(x) &= x - \frac{x^3}{3!}+\frac{x^5}{5!} - \frac{x^7}{7!} + \frac{x^9}{9!} + \cdots = \sum_{k = 1}^\infin\frac{(-1)^{k - 1}x^{2k - 1}}{(2k-1)!}\\ \cos(x) &= 1 - \frac{x^2}{2!} + \frac{x^4}{4!} - \frac{x^6}{6!} + \frac{x^8}{8!} + \cdots = \sum_{k = 0}^\infin\frac{(-1)^{k} x^{2k}}{(2k)!}\\ \mathrm e^x &= 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \frac{x^4}{4!} + \cdots = \sum_{k = 0}^\infin\frac{x^k}{k!} \end{aligned} sin(x)cos(x)ex=x3!x3+5!x57!x7+9!x9+=k=1(2k1)!(1)k1x2k1=12!x2+4!x46!x6+8!x8+=k=0(2k)!(1)kx2k=1+x+2!x2+3!x3+4!x4+=k=0k!xk
x + y i x + y\mathrm i x+yi 代入进去即可推导。

如果 x = 0 x = 0 x=0,我们就得到大名鼎鼎的欧拉公式:
e x i = cos ⁡ x + i sin ⁡ x \mathrm e^{x\mathrm i} = \cos x + \mathrm i\sin x exi=cosx+isinx
更特殊地,如果 x = π x = \pi x=π,得到的就是下面这个神奇的式子:
e π i = − 1 \mathrm e^{\pi\mathrm i} = -1 eπi=1
复平面上我们可以定义类似于平面直角坐标系上的单位圆,单位圆上的所有复数构成集合 { z : ∣ z ∣ = 1 } \{z: |z| = 1\} {z:z=1}。这些复数都可以表示为 cos ⁡ θ + i sin ⁡ θ \cos\theta + \mathrm i\sin\theta cosθ+isinθ e θ i e^{\theta \mathrm i} eθi 的形式。

多项式的表示法

系数表示法:顾名思义
f ( x ) = a 0 + a 1 x + a 2 x 2 + ⋯ + a n x n    ⟺    f ( x ) = { a 0 , a 1 , a 2 , ⋯   , a n } = [ x 0 x 1 x 2 ⋯ x n ] [ a 0 a 1 a 2 ⋮ a n ] f(x) = a_0 + a_1x + a_2x^2 + \cdots + a_nx^n\iff f(x) = \{a_0,a_1,a_2,\cdots,a_n\} = \begin{bmatrix} x^0 & x^1 & x^2 &\cdots & x^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\a_2\\\vdots\\a_n \end{bmatrix} f(x)=a0+a1x+a2x2++anxnf(x)={a0,a1,a2,,an}=[x0x1x2xn]a0a1a2an
点值表示法:

我们知道由一个多项式在 n + 1 n + 1 n+1 个点上的取值是可以唯一确定一个多项式的,其本质也就是线性方程组的解。所以一个 n n n 次多项式可以用 n + 1 n + 1 n+1 个点表示:

f ( x ) = { ( x 0 , y 0 ) , ( x 1 , y 1 ) , ⋯   , ( x n , y n ) } f(x) = \{(x_0,y_0),(x_1,y_1),\cdots,(x_n,y_n)\} f(x)={(x0,y0),(x1,y1),,(xn,yn)}

或者:

[ x 0 0 x 0 1 x 0 2 ⋯ x 0 n x 1 0 x 1 1 x 1 2 ⋯ x 1 n ⋮ ⋮ ⋮ ⋮ x n 0 x n 1 x n 2 ⋯ x n n ] [ a 0 a 1 ⋮ a n ] = [ y 0 y 1 ⋮ y n ] \begin{bmatrix} x_0^0 & x_0^1 & x_0^2 &\cdots &x_0^n\\ x_1^0 & x_1^1 & x_1^2 &\cdots & x_1^n\\ \vdots & \vdots & \vdots & &\vdots\\ x_n^0 & x_n^1 & x_n^2 & \cdots & x_n^n \end{bmatrix}\begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix} =\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix} x00x10xn0x01x11xn1x02x12xn2x0nx1nxnna0a1an=y0y1yn

通过下面的这个形式我们看得出来其就是一个典型的线性方程组的形式,不难证明其解的唯一性。

并且我们发现点值表示法有一个很明显的优势:可以在 O ( n ) O(n) O(n) 的时间内将两个多项式乘起来,只需把对应点的 y y y 乘起来即可。

通俗的来说,FFT 实现的就是快速求多项式乘法的过程:先把系数表示法转成点值表示法(DFT,离散傅里叶变换),乘完之后再把点值还原为插值(IDFT,离散傅里叶逆变换)。可是朴素的 DFT 需要的时间复杂度为 O ( n 2 ) O(n^2) O(n2),IDFT 还回其系数需要高斯消元是 O ( n 3 ) O(n^3) O(n3) 的。而 FFT 利用了一些很特殊很特殊的值加速了 DFT 和 IDFT 的过程,使得总时间复杂度降低到了 O ( n log ⁡ n ) O(n\log n) O(nlogn)

单位复数根

解这个方程:
x n = 1 x^n = 1 xn=1
我们会发现这个方程在实数范围内只有 1 1 1 或者 2 2 2 个解。然而代数基本定理告诉我们这样的方程有 n n n 个复数域上的解。由模长相乘辐角相加我们知道因为最终 x n = 1 x^n = 1 xn=1,所以这些满足条件的 x x x 的模长必定也是 1 1 1。然后需要满足他们的辐角的 n n n 倍能被 2 π 2\pi 2π 整除。

不难发现其就是 n n n 等分单位圆:

img

我们记 n n n 次单位根的第 k k k 个记为 ω n k \omega_n^k ωnk,不难发现 ω k n = e 2 k π i n \omega_k^n = \mathrm e^{\frac{2k\pi i}{n}} ωkn=en2kπi。由此可见,单位复数根具有一些非常好的性质比如:
ω n 0 = ω n n = 1 ω n k = ω 2 n 2 k ω 2 n k + n = − ω 2 n k ( ω 2 n k + n ) 2 = ω n k \begin{aligned} \omega_n^0 = \omega_n^n &= 1\\ \omega_n^k &= \omega_{2n}^{2k}\\ \omega_{2n}^{k + n} &= -\omega_{2n}^k\\ \left(\omega_{2n}^{k + n}\right)^2 &=\omega_n^k \end{aligned} ωn0=ωnnωnkω2nk+n(ω2nk+n)2=1=ω2n2k=ω2nk=ωnk
利用这些性质,我们可以加速 DFT 的过程。FFT 就是利用分治思想加速求每个 f ( ω n k ) f(\omega_n^k) f(ωnk) 的值

DFT

此时 DFT 的分治思想就是分开考虑奇次项和偶次项:

考虑
f ( x ) = a 0 x 0 + a 1 x 1 + a 2 x 2 + ⋯ f(x) = a_0x^0 + a_1x^1 + a_2x^2 + \cdots f(x)=a0x0+a1x1+a2x2+
将其分为两个多项式
f ( x ) = a 0 x 0 + a 2 x 2 + a 4 x 4 + a 6 x 6 + a 8 x 8 + ⋯ + a 1 x 1 + a 3 x 3 + a 5 x 5 + a 7 x 7 + a 9 x 9 + ⋯ = a 0 x 0 + a 2 x 2 + a 4 x 4 + a 6 x 6 + a 8 x 8 + ⋯ + x ( a 1 x 0 + a 3 x 2 + a 5 x 4 + a 7 x 6 + ⋯   ) \begin{aligned} f(x) &= a_0x^0 + a_2x^2 + a_4x^4 + a_6x^6 + a_8x^8 + \cdots +a_1x^1 + a_3x^3 + a_5x^5 + a_7x^7 + a_9x^9 + \cdots\\ &= a_0x^0 + a_2x^2 + a_4x^4 + a_6x^6 + a_8x^8+\cdots +x(a_1x^0 + a_3x^2 + a_5x^4 + a_7x^6 + \cdots) \end{aligned} f(x)=a0x0+a2x2+a4x4+a6x6+a8x8++a1x1+a3x3+a5x5+a7x7+a9x9+=a0x0+a2x2+a4x4+a6x6+a8x8++x(a1x0+a3x2+a5x4+a7x6+)
考虑两个新多项式:
f 0 ( x ) = a 0 x 0 + a 2 x 1 + a 4 x 2 + a 6 x 3 + ⋯ f 1 ( x ) = a 1 x 0 + a 3 x 1 + a 5 x 2 + a 7 x 3 + ⋯ \begin{aligned} f_0(x) &= a_0x^0 + a_2x^1 + a_4x^2 + a_6x^3 + \cdots\\ f_1(x) &= a_1x^0 + a_3x^1 + a_5x^2 + a_7x^3 + \cdots \end{aligned} f0(x)f1(x)=a0x0+a2x1+a4x2+a6x3+=a1x0+a3x1+a5x2+a7x3+
不难发现
f ( x ) = f 0 ( x 2 ) + x f 1 ( x 2 ) f(x) = f_0(x^2) + xf_1(x^2) f(x)=f0(x2)+xf1(x2)
利用单位复数根的性质:
D F T ( f ( ω n k ) ) = D F T ( f 0 ( ω n 2 k ) ) + ω n k D F T ( f 1 ( ω n 2 k ) ) = D F T ( f 0 ( ω n 2 k ) ) + ω n k D F T ( f 1 ( ω n 2 k ) ) \begin{aligned} \mathrm{DFT}(f(\omega_n^k)) &= \mathrm{DFT}(f_0(\omega_n^{2k})) + \omega_n^k\mathrm{DFT}(f_1(\omega_n^{2k}))\\ &=\mathrm{DFT}(f_0(\omega_\frac n2^k)) + \omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)) \end{aligned} DFT(f(ωnk))=DFT(f0(ωn2k))+ωnkDFT(f1(ωn2k))=DFT(f0(ω2nk))+ωnkDFT(f1(ω2nk))

D F T ( f ( ω n k + n 2 ) ) = D F T ( f 0 ( ω n 2 k + n ) ) + ω n k + n 2 D F T ( f 1 ( ω n 2 k + n ) ) = D F T ( f 0 ( ω n n ω n 2 k ) ) − ω n k D F T ( f 1 ( ω n n ω n 2 k ) ) = D F T ( f 0 ( ω n 2 k ) ) − ω n k D F T ( f 1 ( ω n 2 k ) ) \begin{aligned} \mathrm{DFT}(f(\omega_n^{k + \frac n2})) &= \mathrm{DFT}(f_0(\omega_n^{2k + n})) + \omega_{n}^{k + \frac n2}\mathrm{DFT}(f_1(\omega_n^{2k + n}))\\ &=\mathrm{DFT}(f_0(\omega_n^n\omega_n^{2k})) - \omega_n^k\mathrm{DFT}(f_1(\omega_n^n\omega_n^{2k}))\\ &=\mathrm{DFT}(f_0(\omega_\frac n2^k)) - \omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)) \end{aligned} DFT(f(ωnk+2n))=DFT(f0(ωn2k+n))+ωnk+2nDFT(f1(ωn2k+n))=DFT(f0(ωnnωn2k))ωnkDFT(f1(ωnnωn2k))=DFT(f0(ω2nk))ωnkDFT(f1(ω2nk))

其中 k < n 2 k < \displaystyle\frac n2 k<2n。不难发现只要我们求得出 D F T ( f 0 ( ω n 2 k ) ) \mathrm{DFT}(f_0(\omega_\frac n2^k)) DFT(f0(ω2nk)) D F T ( f 1 ( ω n 2 k ) ) \mathrm{DFT}(f_1(\omega_\frac n2^k)) DFT(f1(ω2nk)) 的话,就可以同时求出 D F T ( f ( ω n k ) ) \mathrm{DFT}(f(\omega_n^k)) DFT(f(ωnk)) D F T ( f ( ω n k + n 2 ) ) \mathrm{DFT}(f(\omega_n^{k + \frac n2})) DFT(f(ωnk+2n))。接下来再对 f 0 f_0 f0 f 1 f_1 f1 递归 DFT 即可。其时间复杂度函数是形如下面这样的:
T ( n ) = T ( n / 2 ) + O ( n ) T(n) = T(n/2) + O(n) T(n)=T(n/2)+O(n)
所以总复杂度为 Θ ( n log ⁡ n ) \Theta(n\log n) Θ(nlogn)

实际实现的时候一定要注意传进去的系数一定要是 2 m 2^m 2m 个的,不然分治的过程中左右不一样会出问题。第一次传进去的时候就高位补 0 0 0,补成最高项次数为 2 m − 1 2^{m - 1} 2m1 的多项式。

void dft(int lim, complex *a)
{
    if (lim == 1) return;//常数项直接返回
    complex a1[lim >> 1], a2[lim >> 1];
    for (int i = 0; i < lim; i += 2)
        a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];//把系数按照奇偶分开
    dft(lim >> 1, a1, type);//求 DFT(f_0())
    dft(lim >> 1, a2, type);//求 DFT(f_1())
    complex Wn = complex(cos(2.0 * pi / lim), sin(2.0 * pi / lim)), w = complex(1, 0);
    for (int i = 0; i < (lim >> 1); ++i, w = w * Wn)
    {
        a[i] = a1[i] + w * a2[i];//求 DFT(f(\omega_n^k))
        a[i + (lim >> 1)] = a1[i] - w * a2[i];//求 DFT(f(\omega_n^{k+\fracn2}))
    }
    return;
}

IDFT

好了现在假装我们已经求出了两个多项式的点值表达并已经将他们乘起来,但是我们最终还是要把他还原回去到系数表示的。这个过程就叫做 IDFT。

其实就是我们需要求解下面关于 a a a 的线性方程组:

[ ( ω n 0 ) 0 ( ω n 0 ) 1 ( ω n 0 ) 2 ⋯ ( ω n 0 ) n ( ω n 1 ) 0 ( ω n 1 ) 1 ( ω n 1 ) 2 ⋯ ( ω n 1 ) n ⋮ ⋮ ⋮ ⋮ ( ω n n ) 0 ( ω n n ) 1 ( ω n n ) 2 ⋯ ( ω n n ) n ] [ a 0 a 1 ⋮ a n ] = [ y 0 y 1 ⋮ y n ] \begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix} (ωn0)0(ωn1)0(ωnn)0(ωn0)1(ωn1)1(ωnn)1(ωn0)2(ωn1)2(ωnn)2(ωn0)n(ωn1)n(ωnn)na0a1an=y0y1yn

我们将其乘上左边矩阵的逆:

[ a 0 a 1 ⋮ a n ] = [ ( ω n 0 ) 0 ( ω n 0 ) 1 ( ω n 0 ) 2 ⋯ ( ω n 0 ) n ( ω n 1 ) 0 ( ω n 1 ) 1 ( ω n 1 ) 2 ⋯ ( ω n 1 ) n ⋮ ⋮ ⋮ ⋮ ( ω n n ) 0 ( ω n n ) 1 ( ω n n ) 2 ⋯ ( ω n n ) n ] − 1 [ y 0 y 1 ⋮ y n ] \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix}^{-1} \begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix} a0a1an=(ωn0)0(ωn1)0(ωnn)0(ωn0)1(ωn1)1(ωnn)1(ωn0)2(ωn1)2(ωnn)2(ωn0)n(ωn1)n(ωnn)n1y0y1yn

模相同的正交列向量构成的矩阵的逆是转置的模分之一倍,所以:

[ ( ω n 0 ) 0 ( ω n 0 ) 1 ( ω n 0 ) 2 ⋯ ( ω n 0 ) n ( ω n 1 ) 0 ( ω n 1 ) 1 ( ω n 1 ) 2 ⋯ ( ω n 1 ) n ⋮ ⋮ ⋮ ⋮ ( ω n n ) 0 ( ω n n ) 1 ( ω n n ) 2 ⋯ ( ω n n ) n ] − 1 = 1 n + 1 [ ( ω n − 0 ) 0 ( ω n − 0 ) 1 ( ω n − 0 ) 2 ⋯ ( ω n − 0 ) n ( ω n − 1 ) 0 ( ω n − 1 ) 1 ( ω n − 1 ) 2 ⋯ ( ω n − 1 ) n ⋮ ⋮ ⋮ ⋮ ( ω n − n ) 0 ( ω n − n ) 1 ( ω n − n ) 2 ⋯ ( ω n − n ) n ] \begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix}^{-1} =\frac {1}{n+1} \begin{bmatrix} (\omega_n^{-0})^0 & (\omega_n^{-0})^1 & (\omega_n^{-0})^2 &\cdots &(\omega_n^{-0})^n\\ (\omega_n^{-1})^0 & (\omega_n^{-1})^1 & (\omega_n^{-1})^2 &\cdots & (\omega_n^{-1})^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{-n})^0 & (\omega_n^{-n})^1 & (\omega_n^{-n})^2 & \cdots & (\omega_n^{-n})^n \end{bmatrix} (ωn0)0(ωn1)0(ωnn)0(ωn0)1(ωn1)1(ωnn)1(ωn0)2(ωn1)2(ωnn)2(ωn0)n(ωn1)n(ωnn)n1=n+11(ωn0)0(ωn1)0(ωnn)0(ωn0)1(ωn1)1(ωnn)1(ωn0)2(ωn1)2(ωnn)2(ωn0)n(ωn1)n(ωnn)n

所以不难发现,IDFT 其实就是再做了一遍 DFT,只不过是反起来的。只是算出来最后的系数结果都要除以点值的个数,反应在代码里面就是那个 lim 变量。

不难发现 ω n k \omega_n^k ωnk 的共轭就是虚部取反,所以可以在 DFT 函数里面传一个参数表示是否为 IDFT。

这样子一个递归版的 FFT 就写完了,总体的代码如下:

#include <cstdio>
#include <cctype>
#include <cmath>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)

const int maxn = 2e6 + 5;
const double pi = acos(-1.0);

inline int read()
{
    char c = getchar();
    int s = 0;
    while (!isdigit(c))
        c = getchar();
    while (isdigit(c))
        s = 10 * s + c - '0', c = getchar();
    return s;
}

struct complex
{
    double x, y;
    complex(double xx = 0, double yy = 0)
    {
        x = xx, y = yy;
    }
} a[maxn], b[maxn];

complex operator+(const complex &a, const complex &b) {return complex(a.x + b.x, a.y + b.y);}
complex operator-(const complex &a, const complex &b) {return complex(a.x - b.x, a.y - b.y);}
complex operator*(const complex &a, const complex &b) {return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}

void dft(int lim, complex *a, int type)//type = 1 DFT;type = -1 IDFT
{
    if (lim == 1) return;//返回常数项
    complex a1[lim >> 1], a2[lim >> 1];
    for (int i = 0; i < lim; i += 2)
        a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
    dft(lim >> 1, a1, type);
    dft(lim >> 1, a2, type);
    complex Wn = complex(cos(2.0 * pi / lim), type * sin(2.0 * pi / lim)), w = complex(1, 0);
    for (int i = 0; i < (lim >> 1); ++i, w = w * Wn)
    {
        a[i] = a1[i] + w * a2[i];
        a[i + (lim >> 1)] = a1[i] - w * a2[i];
    }
    return;
}

int main()
{
    int n = read(), m = read();
    FOR(i, 0, n) a[i].x = read();
    FOR(i, 0, m) b[i].x = read();
    int lim = 1;
    while (lim <= n + m) lim <<= 1;//lim一定要大于 n + m
    dft(lim, a, 1);
    dft(lim, b, 1);
    FOR(i, 0, lim)
        a[i] = a[i] * b[i];//点值乘起来
    dft(lim, a, -1);//IDFT还回去
    FOR(i, 0, n + m)
        printf("%d ", (int)(a[i].x / lim + 0.5));//最后要除那个数然后还原回去,四舍五入
    return 0;
}

位逆序置换

然而,上面的代码连模板都跑不过去……

考虑继续优化 DFT 的过程。递归的过程中开了大量的空间并且常数巨大,考虑非递归写法。

只考虑我们对 0 0 0 7 7 7 操作:

递归的过程:

original		0	1	2	3	4	5	6	7
recursion#1		0	2	4	6	1	3	5	7
recursion#2		0	4	2	6	1	5	3	7
recursion#3		0	4	2	6	1	5	3	7
original bin	000	001	010	011	100	101	110	111
now bin			000	100	010	110	001	101	011	111

可见递归到最后的结果无非就是一个二进制反转。

所以我们可以考虑非递归,一开始就先把所有的数放到最后的位置,然后迭代的时候一步步还回去即可。这个过程就是位逆序置换(蝴蝶变换)

考虑处理出 x x x 二进制位翻转之后的数 R ( x ) R(x) R(x)。易知 R ( 0 ) = 0 R(0) = 0 R(0)=0。我们可以从小到大求 R ( x ) R(x) R(x)。很明显, ⌊ x / 2 ⌋ \lfloor x/2\rfloor x/2 的二进制位是 x x x 右移一位,那么如果知道了 R ( ⌊ x / 2 ⌋ ) R(\lfloor x/2\rfloor) R(x/2) 就可以很容易的求出 R ( x ) R(x) R(x),再分 x x x 的奇偶性判断就可以了。
R ( x ) = ⌊ R ( ⌊ x / 2 ⌋ ) 2 ⌋ + ( x   m o d   2 ) × l e n 2 R(x) = \left\lfloor\frac{R(\lfloor x/2\rfloor)}{2}\right\rfloor + (x\bmod 2)\times\frac{len}2 R(x)=2R(x/2)+(xmod2)×2len
举个例子:翻转 ( 10101110 ) 2 (10101110)_2 (10101110)2,首先我们知道它的二分之一倍为 ( 01010111 ) 2 (01010111)_2 (01010111)2,其翻转结果为 ( 11101010 ) 2 (11101010)_2 (11101010)2,除以二变为 ( 01110101 ) 2 (01110101)_2 (01110101)2,由于它是偶数所以前面不用补 1 1 1。不难发现其就是一开始要求的翻转结果。

预处理翻转结果的代码:

while (lim <= n + m) lim <<= 1;
FOR(i, 0, lim - 1)
    rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));

然后在处理翻转的时候只需要下面几行:

FOR(i, 0, lim - 1)
    if (i < rev[i])
        myswap(a[i], a[rev[i]]);

不难验证其正确性。

而且观察我们在求 D F T ( f ( ω n k ) ) \mathrm{DFT}(f(\omega_n^k)) DFT(f(ωnk)) 时我们需要算两遍 ω n k D F T ( f 1 ( ω n 2 k ) ) \omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)) ωnkDFT(f1(ω2nk)),复数的乘法常数很大,考虑使用临时变量记录以降低常数。

这样子的话迭代版的 DFT 过程就很好写了:

void DFT(int lim, complex *a, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(a[i], a[rev[i]]);//先预处理翻转完了的结果
    for (int p = 2; p <= lim; p <<= 1)//模拟合并答案的过程,即为所谓的 n
    {
        int len = p >> 1;//即上面的 n / 2
        complex Wp = complex(cos(2 * pi / p), type * sin(2 * pi / p));//处理出 p 次单位根
        for (int k = 0; k < lim; k += p)//对每一个进行合并
        {
            complex w = complex(1, 0);//处理 \omega_p^0
            for (int l = k; l < k + len; ++l, w = w * Wp)//开始合并
            {
                //此时的 a[l] 就是之前的 a1[i],a[len + l] 就是之前的 a2[i]
                complex tmp = w * a[len + l];
                a[len + l] = a[l] - tmp;//相当于上面的 a[i + (lim >> 1)] = a1[i] - w * a2[i]
                a[l] = a[l] + tmp;//相当于上面的 a[i] = a1[i] + w * a2[i]
            }
        }
    }
}

多项式乘法的实现

总的一个非递归版 FFT 的实现如下(洛谷 P3803):

#include <cstdio>
#include <cctype>
#include <cmath>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)

const int maxn = 3e6 + 5;
const double pi = acos(-1.0);

inline int read()
{
    char c = getchar();
    int s = 0;
    while (!isdigit(c))
        c = getchar();
    while (isdigit(c))
        s = 10 * s + c - '0', c = getchar();
    return s;
}

template<typename T> inline void myswap(T &a, T &b)
{
    T t = a;
    a = b;
    b = t;
    return;
}

struct complex
{
    double x, y;
    complex(double xx = 0, double yy = 0)
    {
        x = xx, y = yy;
    }
} a[maxn], b[maxn];

int rev[maxn];

complex operator+(const complex &a, const complex &b) {return complex(a.x + b.x, a.y + b.y);}
complex operator-(const complex &a, const complex &b) {return complex(a.x - b.x, a.y - b.y);}
complex operator*(const complex &a, const complex &b) {return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}

void DFT(int lim, complex *a, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(a[i], a[rev[i]]);//先预处理翻转完了的结果
    for (int p = 2; p <= lim; p <<= 1)//模拟合并答案的过程,即为所谓的 n
    {
        int len = p >> 1;//即上面的 n / 2
        complex Wp = complex(cos(2 * pi / p), type * sin(2 * pi / p));//处理出 p 次单位根
        for (int k = 0; k < lim; k += p)//对每一个进行合并
        {
            complex w = complex(1, 0);//处理 \omega_p^0
            for (int l = k; l < k + len; ++l, w = w * Wp)//开始合并
            {
                //此时的 a[l] 就是之前的 a1[i],a[len + l] 就是之前的 a2[i]
                complex tmp = w * a[len + l];
                a[len + l] = a[l] - tmp;//相当于上面的 a[i + (lim >> 1)] = a1[i] - w * a2[i]
                a[l] = a[l] + tmp;//相当于上面的 a[i] = a1[i] + w * a2[i]
            }
        }
    }
}

int main()
{
    int n = read(), m = read();
    FOR(i, 0, n) a[i].x = read();
    FOR(i, 0, m) b[i].x = read();
    int lim = 1;
    while (lim <= n + m) lim <<= 1;//补齐高位
    FOR(i, 0, lim - 1)
        rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));//先处理翻转完的结果
    DFT(lim, a, 1);//DFT
    DFT(lim, b, 1);//DFT
    FOR(i, 0, lim)
        a[i] = a[i] * b[i];//对处理出来的点值进行乘法
    DFT(lim, a, -1);//IDFT
    FOR(i, 0, n + m)
        printf("%d ", (int)(a[i].x / lim + 0.5));
    return 0;
}

使用 FFT 来求高精度整数乘法的实现(洛谷 P1919):

#include <cstdio>
#include <cstring>
#include <cmath>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)
#define DEC(i, a, b) for (int i = a; i >= b; --i)

template<typename T> inline void myswap(T &a, T &b) {T t = a; a = b; b = t; return;}

typedef double db;

const int maxn = 3000000 + 5;
const db pi = acos(-1.0);

struct cmplx
{
    db x, y;
    cmplx(db xx = 0, db yy = 0) {x = xx, y = yy;}
} a[maxn], b[maxn];

cmplx operator+(const cmplx &a, const cmplx &b) {return cmplx(a.x + b.x, a.y + b.y);}
cmplx operator-(const cmplx &a, const cmplx &b) {return cmplx(a.x - b.x, a.y - b.y);}
cmplx operator*(const cmplx &a, const cmplx &b) {return cmplx(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}

char s1[maxn], s2[maxn];
int rev[maxn], ans[maxn];

void DFT(cmplx *f, int lim, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(f[i], f[rev[i]]);
    for (int p = 2; p <= lim; p <<= 1)
    {
        int len = p >> 1;
        cmplx Wp(cos(2.0 * pi / p), type * sin(2.0 * pi / p));
        for (int k = 0; k < lim; k += p)
        {
            cmplx w(1, 0);
            for (int l = k; l < k + len; ++l, w = w * Wp)
            {
                cmplx tmp = w * f[l + len];
                f[l + len] = f[l] - tmp;
                f[l] = f[l] + tmp;
            }
        }
    }
    return;
}

int main()
{
    scanf("%s\n%s", s1, s2);
    int n1 = -1, n2 = -1;
    DEC(i, strlen(s1) - 1, 0) a[++n1].x = s1[i] - '0';
    DEC(i, strlen(s2) - 1, 0) b[++n2].x = s2[i] - '0';
    int lim = 1;
    while (lim <= n1 + n2) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));
    DFT(a, lim, 1);
    DFT(b, lim, 1);
    FOR(i, 0, lim)
        a[i] = a[i] * b[i];
    DFT(a, lim, -1);
    FOR(i, 0, lim)
        ans[i] = (int)(a[i].x / lim + 0.5);
    FOR(i, 0, lim)
        if (ans[i] >= 10) ans[i + 1] += ans[i] / 10, ans[i] %= 10, lim += (i == lim);
    while (!ans[lim] && lim > -1) --lim;
    if (lim == -1) puts("0");
    else DEC(i, lim, 0) printf("%d", ans[i]);
    return 0;
}

当然,千万要记得 IDFT 还回去的时候要除以 lim,实在怕记不住就在 DFT 函数里面加几句话直接处理好

if (type == -1)
    FOR(i, 0, lim - 1)
        f[i].x /= lim;

针对多项式乘法:三次变两次优化

我们发现我们在做多项式乘法的时候,需要先 DFT A ( x ) A(x) A(x) B ( x ) B(x) B(x),乘在一起之后再 IDFT 还回来 C ( x ) C(x) C(x),一共进行了三次这样的操作。考虑如何减少我们调用 DFT 的次数。

可以把 B ( x ) B(x) B(x) 的系数放到 A ( x ) A(x) A(x) 系数的虚部上面,即 a + b i a + b\mathrm i a+bi,然后 DFT 一下 A ( x ) A(x) A(x) 再求个平方,得到 A 2 ( x ) A^2(x) A2(x),再 IDFT 回去。我们可以发现得到的系数都是 ( a + b i ) 2 = a 2 − b 2 + 2 a b i (a + b\mathrm i)^2 = a^2 - b^2 + 2ab\mathrm i (a+bi)2=a2b2+2abi 的形式的,所以只需要取出虚部再除以二就得到答案了。

这样的写法可以减小常数,跑的比 NTT 还快。

#include <cstdio>
#include <cctype>
#include <cmath>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)

typedef double db;

const int maxn = 3e6 + 5;
const db pi = acos(-1.0);

inline int read()
{
    char c = getchar();
    int s = 0;
    while (!isdigit(c))
        c = getchar();
    while (isdigit(c))
        s = 10 * s + c - '0', c = getchar();
    return s;
}

template<typename T> inline void myswap(T &a, T &b)
{
    T t = a;
    a = b;
    b = t;
    return;
}

struct cmplx
{
    db x, y;
    cmplx(db xx = 0, db yy = 0)
    {
        x = xx, y = yy;
    }
} a[maxn];

int rev[maxn];

cmplx operator+(const cmplx &a, const cmplx &b) {return cmplx(a.x + b.x, a.y + b.y);}
cmplx operator-(const cmplx &a, const cmplx &b) {return cmplx(a.x - b.x, a.y - b.y);}
cmplx operator*(const cmplx &a, const cmplx &b) {return cmplx(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}

void DFT(cmplx *f, int lim, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(f[i], f[rev[i]]);
    for (int p = 2; p <= lim; p <<= 1)
    {
        int len = p >> 1;
        cmplx Wp(cos(2 * pi / p), type * sin(2 * pi / p));
        for (int k = 0; k < lim; k += p)
        {
            cmplx w(1, 0);
            for (int l = k; l < k + len; ++l, w = w * Wp)
            {
                cmplx tmp = w * f[len + l];
                f[len + l] = f[l] - tmp;
                f[l] = f[l] + tmp;
            }
        }
    }
}

int main()
{
    int n = read(), m = read();
    FOR(i, 0, n) a[i].x = read();
    FOR(i, 0, m) a[i].y = read();
    int lim = 1;
    while (lim <= n + m) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0));
    DFT(a, lim, 1);
    FOR(i, 0, lim - 1)
        a[i] = a[i] * a[i];
    DFT(a, lim, -1);
    FOR(i, 0, n + m)
        printf("%d ", (int)(a[i].y / lim / 2.0 + 0.5));
    return 0;
}

快速数论变换(NTT)

有了 FFT,我们已经有能力在 O ( n log ⁡ n ) O(n\log n) O(nlogn) 的时间内求出两个多项式的卷积了。但是 FFT 也有它的缺点:复数采用的浮点运算不仅造成精度的问题,还会增大常数。遗憾的是数学家们已经证明了 C \mathbb C C 中只有单位复数根满足 FFT 的要求。

考虑到利用多项式的计数题很多都是模意义下的,所以自然希望为单位复数根找一个模意义下的替代品。此时就进入下面的前置知识:原根。

原根

设整数 r , n r,n r,n 满足 r ⊥ n ∧ r ≠ 0 ∧ n > 0 r\perp n\land r \not= 0 \land n > 0 rnr=0n>0,使得 r x ≡ 1 ( m o d n ) r^x \equiv 1\pmod n rx1(modn)最小正整数 x x x 称为 r r r n n n,记为 o r d n r \mathrm{ord}_nr ordnr δ n ( r ) \delta_n(r) δn(r)

r , n ∈ N + ∧ r ⊥ n r,n\in\mathbb N^+\land r\perp n r,nN+rn,当 ord ⁡ n r = ϕ ( n ) \operatorname{ord}_nr = \phi(n) ordnr=ϕ(n) 时,称 r r r 是模 n n n 的原根或者 n n n 的原根。

NTT

对于质数 p = q n + 1   ( n = 2 m ) p = qn + 1\:(n = 2^m) p=qn+1(n=2m),原根 g g g 满足 g q n ≡ 1 ( m o d p ) g^{qn}\equiv 1\pmod p gqn1(modp),将 g n = g q ( m o d p ) g_n = g^q\pmod p gn=gq(modp) 看作 ω n \omega_n ωn 的等价,其满足相似的性质,比如 g n n ≡ 1 ( m o d p ) , g n n / 2 ≡ − 1 ( m o d p ) g_n^n\equiv 1\pmod p,g_n^{n/2} \equiv -1\pmod p gnn1(modp),gnn/21(modp)

常见的质数
p = 998244353 = 7 × 17 × 2 23 + 1 , g = 3 p = 1004535809 = 479 × 2 21 + 1 , g = 3 \begin{aligned} p &= 998244353 = 7\times17\times2^{23} + 1,&g = 3\\ p &= 1004535809 = 479\times 2^{21} + 1,&g = 3 \end{aligned} pp=998244353=7×17×223+1,=1004535809=479×221+1,g=3g=3
迭代到长度为 l l l 时, g l = g p − 1 l g_l = g^{\frac{p - 1}{l}} gl=glp1

直接看代码:

#include <cstdio>
#include <cctype>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)

typedef long long ll;

const ll G = 3;
const ll mod = 998244353;
const int maxn = 3e6 + 5;

inline int read()
{
    char c = getchar();
    int s = 0;
    while (!isdigit(c))
        c = getchar();
    while (isdigit(c))
        s = 10 * s + c - '0', c = getchar();
    return s;
}

template<typename T> inline void myswap(T &a, T &b)
{
    T t = a;
    a = b;
    b = t;
    return;
}

ll pow(ll base, ll p = mod - 2)
{
    ll ret = 1;
    for (; p; p >>= 1)
    {
        if (p & 1)
            ret = ret * base % mod;
        base = base * base % mod;
    }
    return ret;
}

int rev[maxn];
ll f[maxn], g[maxn];
const ll invG = pow(G);

void NTT(ll *f, int lim, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(f[i], f[rev[i]]);
    for (int p = 2; p <= lim; p <<= 1)
    {
        int len = p >> 1;
        ll tG = pow(type ? G : invG, (mod - 1) / p);
        for (int k = 0; k < lim; k += p)
        {
            ll buf = 1;
            for (int l = k; l < k + len; ++l, buf = buf * tG % mod)
            {
                ll tmp = buf * f[len + l] % mod;
                f[len + l] = f[l] - tmp;
                if (f[len + l] < 0) f[len + l] += mod;//及时取模
                f[l] = f[l] + tmp;
                if (f[l] > mod) f[l] -= mod;//及时取模
            }
        }
    }
    ll invlim = pow(lim);//最后还回去,除以lim相当于乘上lim的逆元
    if (!type)
        FOR(i, 0, lim - 1)
            f[i] = (f[i] * invlim % mod);
    return;
}

int main()
{
    int n = read(), m = read();
    FOR(i, 0, n) f[i] = read();
    FOR(i, 0, m) g[i] = read();
    int lim = 1;
    while (lim <= n + m) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    NTT(f, lim, 1), NTT(g, lim, 1);
    FOR(i, 0, lim - 1)
        f[i] = f[i] * g[i] % mod;
    NTT(f, lim, 0);
    FOR(i, 0, n + m)
        printf("%d ", (int)f[i]);
    return 0;
}

FFT/NTT 优化卷积的一些例子

在继续之前,我们先来看看 FFT/NTT 的一些应用。(高精度乘法就不说了,记得最后进位就可以了)

  • 优化一般的卷积
  • 和生成函数一起食用
  • 字符串匹配(你没看错)

洛谷 P3338 [ZJOI2014]力

题意:给定 { q } \{q\} {q},定义
F i = ∑ j = 1 i − 1 q i q j ( i − j ) 2 − ∑ j = i + 1 n q i q j ( i − j ) 2 F_i = \sum_{j = 1}^{i - 1}\frac{q_iq_j}{(i - j)^2} - \sum_{j = i + 1}^n\frac{q_iq_j}{(i - j)^2} Fi=j=1i1(ij)2qiqjj=i+1n(ij)2qiqj

E i = F i q i E_i=\frac{F_i}{q_i} Ei=qiFi
考虑暴力的话,这道题是 O ( n 2 ) O(n^2) O(n2) 的,过不去,考虑转化式子:
E i = F i q i = ∑ j = 1 i − 1 q j ( i − j ) 2 − ∑ j = i + 1 n q j ( i − j ) 2 \begin{aligned} E_i &= \frac{F_i}{q_i}\\ &=\sum_{j = 1}^{i - 1}\frac{q_j}{(i - j)^2} - \sum_{j = i + 1}^n\frac{q_j}{(i - j)^2}\\ \end{aligned} Ei=qiFi=j=1i1(ij)2qjj=i+1n(ij)2qj
我们尝试将其化为卷积的形式,令 f i = q i f_i = q_i fi=qi,且 f 0 = 0 f_0 = 0 f0=0 g i = 1 i 2 g_i =\dfrac{1}{i^2} gi=i21,且 g 0 = 0 g_0 = 0 g0=0,回代:
E i = ∑ j = 0 i f j g i − j − ∑ j = i n f j g j − i \begin{aligned} E_i &= \sum_{j = 0}^{i}f_jg_{i - j} - \sum_{j = i}^nf_jg_{j - i} \end{aligned} Ei=j=0ifjgijj=infjgji
左边的部分已经是一个卷积的形式了,考虑继续化简右边。此时我们可以使用一个翻转的技巧,令 f i ′ = f n − i f'_i = f_{n - i} fi=fni t = n − i t = n - i t=ni,则右半边的式子可以化为 ∑ j = 0 t f t − j ′ g j \displaystyle\sum_{j = 0}^{t}f'_{t - j}g_j j=0tftjgj。现在两边都化为卷积的形式了,可以愉快的使用 FFT 加速了。

即我们设多项式 A ( x ) = ∑ i = 0 n f i x n A(x) =\displaystyle\sum_{i = 0}^nf_ix^n A(x)=i=0nfixn B ( x ) = ∑ i = 0 n g i x n B(x) = \displaystyle\sum_{i = 0}^ng_ix^n B(x)=i=0ngixn C ( x ) = ∑ i = 0 n f i ′ C(x) = \displaystyle\sum_{i = 0}^nf'_i C(x)=i=0nfi。再令 L ( x ) = A ( x ) B ( x ) L(x) = A(x)B(x) L(x)=A(x)B(x) R ( x ) = B ( x ) C ( x ) R(x) = B(x)C(x) R(x)=B(x)C(x),不难发现答案 E i = l i − r n − i E_i = l_i - r_{n - i} Ei=lirni,其中 l i l_i li r i r_i ri 分别为 L ( x ) L(x) L(x) R ( x ) R(x) R(x) x i x^i xi 的系数。

int main()
{
    int n; scanf("%d", &n);
    FOR(i, 1, n)
    {
        scanf("%lf", &a[i].x);
        b[i].x = (1.0 / i / i);
        c[n - i].x = a[i].x;
    }
    int lim = 1;
    while (lim <= (n << 1)) lim <<= 1;
    FOR(i, 0, lim)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0));
    DFT(a, lim, 1), DFT(b, lim, 1), DFT(c, lim, 1);
    FOR(i, 0, lim)
        a[i] = a[i] * b[i], c[i] = b[i] * c[i];
    DFT(a, lim, -1), DFT(c, lim, -1);
    FOR(i, 1, n)
        printf("%.3lf\n", a[i].x - c[n - i].x);
    return 0;
}

洛谷 P3723 [AH2017/HNOI2017]礼物

题意:给定两个序列 { x } \{x\} {x} { y } \{y\} {y},可以整体平移序列或者整体加/减某个数,求最终序列

∑ i = 1 n ( x i − y i ) 2 \sum_{i = 1}^n(x_i - y_i)^2 i=1n(xiyi)2

的最小值。

分析:设整体加减的数为 c c c c c c 可正可负),我们需要最小化的就是下面这个式子:

∑ i = 1 n ( x i − y i + c ) 2 \sum_{i = 1}^n(x_i - y_i + c)^2 i=1n(xiyi+c)2

展开上面的式子,由 ( x i − y i + c ) 2 = x i 2 + y i 2 + c 2 + 2 x i c − 2 y i c − 2 x i y i (x_i - y_i +c)^2 = x_i^2 + y_i^2 + c^2 + 2x_ic - 2y_ic - 2x_iy_i (xiyi+c)2=xi2+yi2+c2+2xic2yic2xiyi 可以得到原式可化简为

∑ x i 2 + ∑ y i 2 + n c 2 + 2 c ∑ x i − 2 c ∑ y i − 2 ∑ x i y i \sum x_i^2 + \sum y_i^2 + nc^2 + 2c\sum x_i - 2c\sum y_i - 2\sum x_iy_i xi2+yi2+nc2+2cxi2cyi2xiyi

(下标省略)

不难发现我们只需要最大化 ∑ x i y i \sum x_iy_i xiyi 就可以啦。

f k f_k fk 为旋转了 k k k 个单位后 ∑ x i y i \sum x_iy_i xiyi 的取值,先把 { x } \{x\} {x} 倍长一波,则

f k = ∑ i = 1 n x i + k y k f_k = \sum_{i = 1}^nx_{i + k}y_k fk=i=1nxi+kyk

翻转 y y y

f k = ∑ i = 1 n x i + k y n − i + 1 ′ f_k = \sum_{i = 1}^nx_{i + k}y_{n - i + 1}' fk=i=1nxi+kyni+1

考虑多项式 f ( t ) = ∑ i = 1 n x i t i f(t) = \sum_{i = 1}^n x_it^i f(t)=i=1nxiti g ( t ) = ∑ i = 1 n y i t i g(t) = \sum_{i = 1}^n y_it^i g(t)=i=1nyiti,令 h ( t ) = f ( t ) ∗ g ( t ) h(t) = f(t) * g(t) h(t)=f(t)g(t),不难发现其 t n + k + 1 t^{n + k + 1} tn+k+1 的系数即为 f k f_k fk。因此可以使用 FFT/NTT 将倍长过的 { x } \{x\} {x} 与翻转过的 { y } \{y\} {y} 卷起来,然后把结果从第 n + 1 n + 1 n+1 到第 2 n 2n 2n 处找最值就可以了

NTT 的实现:

ll a[maxn], b[maxn];
ll suma, sumb, suma2, sumb2, n, m;
int rev[maxn];

void NTT(ll *f, int lim, int type)
{
    FOR(i, 0, lim - 1)
        if (i < rev[i])
            swap(f[i], f[rev[i]]);
    for (int p = 2; p <= lim; p <<= 1)
    {
        int len = p >> 1;
        ll Gp = pow(type ? G : invG, (mod - 1) / p);
        for (int k = 0; k < lim; k += p)
        {
            ll buf = 1;
            for (int l = k; l < k + len; ++l, buf = buf * Gp % mod)
            {
                ll tmp = buf * f[l + len] % mod;
                f[l + len] = f[l] - tmp;
                if (f[l + len] < 0) f[l + len] += mod;
                f[l] = f[l] + tmp;
                if (f[l] > mod) f[l] -= mod;
            }
        }
    }
    ll invlim = pow(lim);
    if (!type)
        FOR(i, 0, lim - 1)
            f[i] = f[i] * invlim % mod;
    return;
}

int main()
{
    n = read(), m = read();
    FOR(i, 1, n)
        a[i] = a[i + n] = read(), suma += a[i], suma2 += a[i] * a[i];
    FOR(i, 1, n)
        b[n - i + 1] = read(), sumb += b[n - i + 1], sumb2 += b[n - i + 1] * b[n - i + 1];
    int lim = 1;
    while (lim <= 3 * n) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    NTT(a, lim, 1), NTT(b, lim, 1);
    FOR(i, 0, lim)
        a[i] = a[i] * b[i] % mod;//千万不要忘记取模
    NTT(a, lim, 0);
    ll ans = 1e18;
    FOR(i, 1, n)
        FOR(j, -m, m)
            ans = min(ans, suma2 + sumb2 + j * j * n + 2 * j * (suma - sumb) - 2 * a[i + n]);
    printf("%lld\n", ans);
    return 0;
}

BZOJ3771 Triple

题意:有 n n n 把价值分别为 a i a_i ai 的斧子,河神可能拿走 1 - 3 把,问每种可能的损失价值及其对应方案数。(不计顺序)

思路:这是一道生成函数的入门题。

考虑设出多项式 A ( x ) A(x) A(x),其系数有 A [ a i ] = 1 A[a_i] = 1 A[ai]=1,代表选一把的。则你可能会觉得答案为 A ( x ) + A 2 ( x ) + A 3 ( x ) A(x) + A^2(x) + A^3(x) A(x)+A2(x)+A3(x)。但是这样是显然不对的。为什么?

因为这样的话同一个元素可能被选两次或三次,对于这种情况定义 B ( x ) B(x) B(x) C ( x ) C(x) C(x) 满足 B [ 2 a i ] = 1 B[2a_i] = 1 B[2ai]=1 C [ 3 a i ] = 1 C[3a_i] = 1 C[3ai]=1,代表同时选两次/三次的,减掉这些方案数就可以了。然后需要注意顺序问题:

选一把的答案为 A ( x ) A(x) A(x),不难发现选两种的即为 A 2 ( x ) − B ( x ) 2 \dfrac{A^2(x) - B(x)}{2} 2A2(x)B(x),选三种的比较麻烦:不能同时选两种一样的,即减去 3 A ( x ) B ( x ) 3A(x)B(x) 3A(x)B(x),,但是选三种同样的又会被多减两次,最后除以 3 ! 3! 3! 去掉顺序问题,所以最终答案为:

A ( x ) + A 2 ( x ) − B ( x ) 2 + A 3 ( x ) − 3 A ( x ) B ( x ) + 2 C ( x ) 6 A(x) + \frac{A^2(x) - B(x)}{2} + \frac{A^3(x) - 3A(x)B(x) + 2C(x)}{6} A(x)+2A2(x)B(x)+6A3(x)3A(x)B(x)+2C(x)

生成函数的卷积使用 NTT 或 FFT 优化即可。注意此时 NTT 模数要取一个更大的质数。不知道为什么生成函数能这样对应的可以意会一下多项式卷积的定义式以及这些系数的组合意义

ll f1[maxn], f2[maxn], f3[maxn], ans[maxn];
ll g[maxn], t[maxn];

int main()
{
    int n = read();
    while (n--)
    {
        int tmp = read();
        ++f1[tmp], ++g[tmp], ++ans[tmp];
        ++f2[tmp << 1], ++f3[tmp * 3];
    }
    int lim = 1;
    while (lim <= (40000 * 3 + 5)) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    NTT(f1, lim, 1), NTT(g, lim, 1);
    FOR(i, 0, lim - 1)
        g[i] = f1[i] * g[i] % mod;
    NTT(g, lim, 0);
    FOR(i, 0, lim - 1)
        ans[i] += (g[i] - f2[i]) / 2;
    NTT(g, lim, 1);
    FOR(i, 0, lim - 1)
        g[i] = f1[i] * g[i] % mod;
    NTT(g, lim, 0);
    NTT(f2, lim, 1);
    FOR(i, 0, lim - 1)
        f2[i] = f2[i] * f1[i] % mod;
    NTT(f2, lim, 0);
    FOR(i, 0, lim - 1)
    {
        ans[i] += (g[i] - 3 * f2[i] + 2 * f3[i]) / 6;
        if (ans[i]) printf("%d %lld\n", i, ans[i]);
    }
    return 0;
}

FFT/NTT 与字符串匹配

字符串下标从 1 1 1 开始

最一般的情况

考虑文本串 S S S 和模式串 T T T,串长 n = ∣ S ∣ n = |S| n=S m = ∣ T ∣ m = |T| m=T,保证 n ≥ m n \ge m nm,现在需要找出 T T T S S S 中出现的每个位置。直接跑 KMP 就可以了,但是这个不是要提的重点。考虑串 S S S 的第 i i i 个字符为 S [ i ] S[i] S[i],那么匹配就可以写成 S [ i ] − T [ j ] = 0 S[i] - T[j] = 0 S[i]T[j]=0,这个应该是比较好想的。

假设 T T T S S S 的第 i i i 位开始成功匹配,则我们有

∑ j = 1 m ( S [ i + j − 1 ] − T [ j ] ) 2 = 0 \sum_{j = 1}^{m} (S[i + j - 1] - T[j])^2 = 0 j=1m(S[i+j1]T[j])2=0

为了防止正负号相互抵消,所以需要平方。由于 i + j − 1 + j i + j - 1 + j i+j1+j 不是定值,不符合我们需要的卷积的形式,所以翻转一下 T T T 让其变为 T ′ T' T

∑ j = 1 m ( S [ i + j − 1 ] − T [ m − j + 1 ] ) 2 = 0 \sum_{j = 1}^{m} (S[i + j - 1] - T[m - j + 1])^2 = 0 j=1m(S[i+j1]T[mj+1])2=0

打开来我们就会发现

∑ j = 1 m ( S [ i + j − 1 ] 2 + T [ j ] 2 − 2 S [ i + j − 1 ] T [ m − j + 1 ] ) = 0 \sum_{j = 1}^{m} (S[i + j - 1]^2 + T[j]^2 - 2S[i + j - 1]T[m - j + 1]) = 0 j=1m(S[i+j1]2+T[j]22S[i+j1]T[mj+1])=0

i + j − 1 + m − j + 1 = i + m i + j - 1 + m - j + 1 = i + m i+j1+mj+1=i+m,为定值。

所以上面的式子就可以化成

∑ j = 1 m S [ i + j − 1 ] 2 + ∑ j = 1 m T [ j ] 2 − 2 ∑ x + y = i + m S [ x ] T [ y ] \sum_{j = 1}^m S[i + j - 1]^2 + \sum_{j = 1}^m T[j]^2 - 2\sum_{x + y = i + m}S[x]T[y] j=1mS[i+j1]2+j=1mT[j]22x+y=i+mS[x]T[y]

第一项直接前缀和就可以解决,第二项常数,第三项使用 FFT/NTT。

更加好理解地,设 f ( x ) = ∑ i + j = x + m S [ i ] T [ j ] f(x) = \sum_{i + j = x + m}S[i]T[j] f(x)=i+j=x+mS[i]T[j],我们只需要求出这个 f ( x ) f(x) f(x) 就可以了。

问题来了,这样难写复杂度高常数大全方位被 KMP 吊打的算法有什么存在的意义吗?对不起还真的有:

带通配符的字符串匹配

请看例题 洛谷 P4173 残缺的字符串。仍然是字符串匹配,但是每个串都有通配符,这个时候 KMP 就显得无能为力了。怎么办呢?好好思考一下两个字符如何才能匹配:

  • 两个字符完全一样
  • 其中至少一个为通配符

两者是逻辑或的关系,我们魔改一下上面的式子,不难发现我们只需要把通配符的值设为 0 0 0 就可以解决了:

定义匹配函数 F ( x ) F(x) F(x) 表示 S S S 的第 x x x 位开始和 T T T 是否匹配,匹配的话 F ( x ) = 0 F(x) = 0 F(x)=0

F ( x ) = ∑ j = 1 m ( S [ i + j − 1 ] − T [ j ] ) 2 S [ i + j − 1 ] T [ j ] F(x) = \sum_{j = 1}^m(S[i + j - 1] - T[j])^2S[i + j - 1]T[j] F(x)=j=1m(S[i+j1]T[j])2S[i+j1]T[j]

化简:

F ( x ) = ∑ j = 1 m ( S [ x + j − 1 ] − T [ j ] ) 2 S [ x + j − 1 ] T [ j ] = ∑ j = 1 m ( S [ x + j − 1 ] − T ′ [ m − j + 1 ] ) 2 S [ x + j − 1 ] T ′ [ m − j + 1 ] = ∑ j = 1 m ( S [ x + j − 1 ] 2 + T ′ [ m − j + 1 ] 2 − 2 S [ x + j − 1 ] T ′ [ m − j + 1 ] ) S [ x + j − 1 ] T ′ [ m − j + 1 ] = ∑ j = 1 m S [ x + j − 1 ] 3 T ′ [ m − j + 1 ] + ∑ j = 1 m S [ x + j − 1 ] T ′ [ m − j + 1 ] 3 − 2 ∑ j = 1 m S [ x + j − 1 ] 2 T ′ [ m − j + 1 ] 2 = ∑ i + j = x + m S [ i ] 3 T ′ [ j ] + ∑ i + j = x + m S [ i ] T ′ [ j ] 3 − 2 ∑ i + j = x + m S [ i ] 2 T [ j ] 2 \begin{aligned} F(x) &= \sum_{j = 1}^m(S[x + j - 1] - T[j])^2S[x + j - 1]T[j]\\ &= \sum_{j = 1}^m(S[x + j - 1] - T'[m - j + 1])^2S[x + j - 1]T'[m - j + 1]\\ &= \sum_{j = 1}^m(S[x + j - 1]^2 + T'[m - j + 1]^2 - 2S[x + j - 1]T'[m - j + 1])S[x + j - 1]T'[m - j + 1]\\ &= \sum_{j = 1}^m S[x + j - 1]^3T'[m - j + 1] + \sum_{j = 1}^m S[x + j - 1]T'[m - j + 1]^3- 2\sum_{j = 1}^m S[x + j - 1]^2T'[m - j + 1]^2\\ &= \sum_{i + j = x + m}S[i]^3T'[j] + \sum_{i + j = x + m}S[i]T'[j]^3 - 2\sum_{i + j = x + m}S[i]^2T[j]^2 \end{aligned} F(x)=j=1m(S[x+j1]T[j])2S[x+j1]T[j]=j=1m(S[x+j1]T[mj+1])2S[x+j1]T[mj+1]=j=1m(S[x+j1]2+T[mj+1]22S[x+j1]T[mj+1])S[x+j1]T[mj+1]=j=1mS[x+j1]3T[mj+1]+j=1mS[x+j1]T[mj+1]32j=1mS[x+j1]2T[mj+1]2=i+j=x+mS[i]3T[j]+i+j=x+mS[i]T[j]32i+j=x+mS[i]2T[j]2

于是问题就解决了,只需要用 NTT/FFT 计算出上面三项恶心的东西出来就 OK 了。一共进行 7 7 7 次 NTT 即可。

需要注意的是最后枚举答案的时候只能枚举到 n − m + 1 n - m + 1 nm+1 处,否则只有 35 35 35 分。

int f[maxn << 1], g[maxn << 1], f2[maxn << 1], g2[maxn << 1], f3[maxn << 1], g3[maxn << 1];
int ans[maxn << 1], vec[maxn << 1], tot;

int n, m;
char a[maxn], b[maxn];

int main()
{
    m = readInt(), n = readInt();
    scanf("%s", a + 1);
    scanf("%s", b + 1);
    FOR(i, 1, m)
    {
        g[i] = (a[m - i + 1] == '*') ? 0 : a[m - i + 1] - 'a' + 1;
        g2[i] = g[i] * g[i], g3[i] = g2[i] * g[i];
    }
    FOR(i, 1, n)
    {
        f[i] = (b[i] == '*') ? 0 : b[i] - 'a' + 1;
        f2[i] = f[i] * f[i], f3[i] = f2[i] * f[i];
    }
    int lim = 1;
    while (lim <= n + m) lim <<= 1;
    NTT(f, lim, 1), NTT(f2, lim, 1), NTT(f3, lim, 1);
    NTT(g, lim, 1), NTT(g2, lim, 1), NTT(g3, lim, 1);
    FOR(i, 0, lim - 1)
        ans[i] = (1ll * f3[i] * g[i] % mod + 1ll * f[i] * g3[i] % mod - 2ll * f2[i] * g2[i] % mod) % mod;
    NTT(ans, lim, 0);
    int cnt = 0;
    FOR(i, 1, n - m + 1)
        if (ans[i + m] == 0)
            ++cnt, vec[++tot] = i;
    printf("%d\n", cnt);
    FOR(i, 1, tot) printf("%d ", vec[i]);
    return 0;
}
另外一道例题

请看 CF528D Fuzzy Search。题意:字符串匹配, 1 ≤ ∣ T ∣ ≤ ∣ S ∣ ≤ 2 × 1 0 5 1\le |T| \le |S| \le 2\times 10^5 1TS2×105,字符集只有 ATCG \texttt{ATCG} ATCG T T T S S S 中的第 i i i 个位置出现当且仅当 ∀ j ∈ [ 1 , T ] \forall j\in [1,T] j[1,T] ∃ p \exist p p 使得 ∣ i + j − 1 − p ∣ ≤ k ∧ S [ p ] = T [ j ] |i + j - 1 - p|\le k \land S[p] = T[j] i+j1pkS[p]=T[j]。即偏移量不能超过 k k k

注意到字符集很小,只有 4 4 4 个字符,所以我们可以把字符串 01 化,分开考虑每个字母。比如 S = ATCGAA S = \texttt{ATCGAA} S=ATCGAA T = ACAA T = \texttt{ACAA} T=ACAA,现在只考虑字母 A \texttt A A,把 A \texttt A A 化成 1 1 1 而其他的化为 0 0 0,则 S = 100011 S = \texttt{100011} S=100011 T = 1011 T= \texttt{1011} T=1011。然后假设 k = 1 k = 1 k=1,把能扩展的都往两边扩展,则 S = 110111 S = \texttt{110111} S=110111

这个时候,我们就可以设匹配 F ( x , c ) F(x, c) F(x,c) 表示 T T T S S S 的第 x x x 位开始字符 c c c 能匹配的数量,最终答案为 A ( x ) = ∑ c ∈ { A,T,C,G } F ( x , c ) A(x) = \sum_{c\in\lbrace\texttt{A,T,C,G}\rbrace}F(x,c) A(x)=c{A,T,C,G}F(x,c) T T T 匹配成功当且仅当 A ( x ) = m A(x) = m A(x)=m,即所有字符都匹配到了,否则失败。

不难发现 F ( x , c ) = ∑ j = 1 m S [ x + j − 1 ] T [ j ] F(x, c) = \sum_{j = 1}^mS[x + j - 1]T[j] F(x,c)=j=1mS[x+j1]T[j],老套路翻转一下变为 F ( x , c ) = ∑ j = 1 m S [ x + j − 1 ] T [ m − j + 1 ] = ∑ i + j = x + m S [ i ] T [ j ] F(x, c) = \sum_{j = 1}^m S[x + j - 1]T[m - j + 1] = \sum_{i + j = x + m}S[i]T[j] F(x,c)=j=1mS[x+j1]T[mj+1]=i+j=x+mS[i]T[j]。这就是喜闻乐见的卷积形式了。NTT 直接上就完了。

int f[maxn << 1], g[maxn << 1];
int ans[maxn << 1];

int n, m, k, lim = 1;
char s[maxn << 1], t[maxn << 1];

void proc(char c)
{
    clr(f, lim), clr(g, lim);
    for (int i = 1, lst = -1e9; i <= n; ++i)
    {
        if (s[i] == c)
            lst = i;
        if (i - lst <= k)
            f[i] = 1;
    }
    for (int i = n, lst = 1e9; i; --i)
    {
        if (s[i] == c)
            lst = i;
        if (lst - i <= k)
            f[i] = 1;
    }
    FOR(i, 1, m)
        g[i] = (t[m - i + 1] == c);
    NTT(f, lim, 1), NTT(g, lim, 1);
    FOR(i, 0, lim - 1)
        f[i] = 1ll * f[i] * g[i] % mod;
    NTT(f, lim, 0);
    FOR(i, 1, n)
        ans[i] += f[i + m];
}

int main()
{
    n = readInt(), m = readInt(), k = readInt();
    scanf("%s", s + 1);
    scanf("%s", t + 1);
    while (lim <= n + m) lim <<= 1;
    FOR(i, 0, 3) proc("ATCG"[i]);
    int cnt = 0;
    FOR(i, 1, n)
        cnt += (ans[i] == m);
    printf("%d\n", cnt);
    return 0;
}

关于封装

以后的全家桶会大量使用 NTT 等基础操作,考虑实现一个常数较小的封装:

首先是各种 #define,由于我们在进行多项式运算的时候需要考虑界的问题,因此一定要把超过界了的给清零,不然可能出现各种奇奇怪怪的问题比如多卷了之类的:

#define ll long long
#define ull unsigned ll
#define FOR(i, a, b) for (int i = a; i <= b; ++i)
#define clr(f, n) memset(f, 0, (sizeof(int)) * (n))
#define cpy(f, g, n) memcpy(f, g, (sizeof(int)) * (n))

memsetmemcpy 的用法建议自己去查。

接下来是一些基本的东西:

const ll G = 3, mod = 998244353;
const int maxn = ((1 << 21) + 500);

ll qpow(ll base, ll p = mod - 2)
{
    ll ret = 1;
    for (; p; p >>= 1)
    {
        if (p & 1)
            ret = ret * base % mod;
        base = base * base % mod;
    }
    return ret;
}

const ll invG = qpow(G);

没什么说的

NTT 时需要用到的位逆序置换:

int rev[maxn << 1], revlim;

void get_rev(int lim)
{
    if (lim == revlim) return;
    revlim = lim;
    FOR(i, 0, lim - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
    return;
}

这样子可以在需要进行很多次 NTT 的时候智能的求出对应需要的 rev

NTT 和乘法的封装:关于 static 关键字相关的建议自己查一下。

void NTT(int *g, int n, int type)
{
    get_rev(n);
    static ull f[maxn << 1], w[maxn];
    w[0] = 1;
    FOR(i, 0, n - 1)
        f[i] = (((long long)mod << 5ll) + g[rev[i]]) % mod;//防止负数带来影响
    for (int l = 1; l < n; l <<= 1)
    {
        ull tmp = qpow(type ? G : invG, (mod - 1) / (l << 1));
        FOR(i, 1, l - 1) w[i] = w[i - 1] * tmp % mod;//预处理“单位根”
        for (int i = 0; i < n; i += (l << 1))
        {
            for (int j = 0; j < l; ++j)
            {
                ll tt = w[j] * f[i + j + l] % mod;
                f[i + j + l] = f[i + j] + mod - tt;
                f[i + j] += tt;
            }
        }
        if (l == (1 << 10))
            FOR(i, 0, n - 1) f[i] %= mod;
    }
    if (!type)
    {
        ull invn  = qpow(n);
        FOR(i, 0, n - 1)
            g[i] = f[i] % mod * invn % mod;
    }
    else FOR(i, 0, n - 1)
        g[i] = f[i] % mod;
    return;
}

void times(int *f, int *g, int len, int lim)//len 表示两个多项式的最高次数,lim 为最终需要的项数
{
    static int sav[maxn << 1];//临时变量
    int n = 1;
    while (n < (len << 1)) n <<= 1;
    clr(sav, n), cpy(sav, g, n);
    NTT(f, n, 1); NTT(sav, n, 1);
    FOR(i, 0, n - 1)
        f[i] = 1ll * f[i] * sav[i] % mod;
    NTT(f, n, 0);
    clr(f + lim, n - lim), clr(sav, n);//把界以上的部分清掉,把 sav 清干净
    return;
}

有了如上的封装,我们在写 P3803 时主函数里面就可以帅气的写:

int main()
{
    n = readInt(), m = readInt();
    FOR(i, 0, n) f[i] = readInt();
    FOR(i, 0, m) g[i] = readInt();
    times(f, g, max(m, n), m + n + 1);
    FOR(i, 0, m + n) printf("%d ", f[i]);
    return 0;
}

就做完了。

多项式乘法逆

定义

当两个多项式 F ( x ) F(x) F(x) G ( x ) G(x) G(x) 在每一项系数模 p p p 时有 F ( x ) ∗ G ( x ) ≡ 1 ( m o d x n ) F(x)*G(x)\equiv 1\pmod{x^n} F(x)G(x)1(modxn) 时,称 F ( x ) F(x) F(x) G ( x ) G(x) G(x) 互为乘法逆元。此处 ( m o d x n ) \pmod{x^n} (modxn) 代表次数高于 n n n 的项都不考虑。

需要的前置知识:NTT

求法

考虑倍增。假设我们要求满足 F ( x ) ∗ G ( x ) ≡ 1 ( m o d x k ) F(x)*G(x)\equiv 1\pmod{x^k} F(x)G(x)1(modxk) G ( x ) G(x) G(x),并且已经求出了满足 F ( x ) ∗ G ′ ( x ) ≡ 1 ( m o d x ⌈ x 2 ⌉ ) F(x)*G'(x)\equiv1\pmod{x^{\lceil\frac x 2\rceil}} F(x)G(x)1(modx2x) G ‘ ( x ) G‘(x) G(x)。那么我们由
F ( x ) ∗ G ′ ( x ) ≡ 1 ( m o d x ⌈ x 2 ⌉ ) F(x)*G'(x)\equiv1\pmod{x^{\lceil\frac x 2\rceil}} F(x)G(x)1(modx2x)
必然可以推出
F ( x ) ∗ G ( x ) ≡ 1 ( m o d x ⌈ x 2 ⌉ ) F(x)*G(x)\equiv1\pmod{x^{\lceil\frac x 2\rceil}} F(x)G(x)1(modx2x)
那么
G ( x ) ≡ G ′ ( x ) ( m o d x ⌈ x 2 ⌉ ) G(x) \equiv G'(x)\pmod{x^{\lceil\frac x 2\rceil}} G(x)G(x)(modx2x)
作差,
G ( x ) − G ′ ( x ) ≡ 0 ( m o d x ⌈ x 2 ⌉ ) G(x) - G'(x)\equiv 0\pmod{x^{\lceil\frac x2\rceil}} G(x)G(x)0(modx2x)
将两边同时平方,不难发现界会从 ⌈ n 2 ⌉ \lceil\frac n2\rceil 2n 变为 n n n
G 2 ( x ) − 2 G ( x ) G ′ ( x ) + G ′ 2 ( x ) ≡ 0 ( m o d x n ) G^2(x) - 2G(x)G'(x) + G'^2(x)\equiv 0\pmod{x^n} G2(x)2G(x)G(x)+G2(x)0(modxn)
现在要求的是 G ( x ) G(x) G(x),而我们发现 G 2 ( x ) G^2(x) G2(x) 不好处理,怎么办?同时乘以 F ( x ) F(x) F(x) 就可以消掉一个 G ( x ) G(x) G(x)
G ( x ) − 2 G ′ ( x ) + G ′ 2 ( x ) F ( x ) ≡ 0 ( m o d x n ) G(x) - 2G'(x) + G'^2(x)F(x)\equiv 0\pmod{x^n} G(x)2G(x)+G2(x)F(x)0(modxn)
所以我们得到了
G ( x ) ≡ 2 G ′ ( x ) − G ′ 2 ( x ) F ( x ) ( m o d x n ) G(x)\equiv 2G'(x) - G'^2(x)F(x)\pmod{x^n} G(x)2G(x)G2(x)F(x)(modxn)
根据这个,我们就可以从 G ′ ( x ) G'(x) G(x) 推出 G ( x ) G(x) G(x) 的值出来了。从上往下递归求解,到常数项的时候直接费马小定理求逆元然后一步步回溯上去。不难发现复杂度为 T ( n ) = T ( n / 2 ) + O ( n log ⁡ n ) T(n) = T(n/2) + O(n\log n) T(n)=T(n/2)+O(nlogn),由主定理知总复杂度为 O ( n log ⁡ n ) O(n\log n) O(nlogn)

实现

递归版多项式求逆:

int tmp[maxn << 1];//这是临时数组

void invpoly(int *f, int *ans, int m)
{
    if (m == 1)
        return ans[0] = qpow(f[0]), void();//常数项就直接返回
    invpoly(f, ans, (m + 1) >> 1);//先递归求出 m / 2 向上取整的情况
    int n = 1;
    while (n < (m << 1)) n <<= 1;
    cpy(tmp, f, m), clr(tmp + m, n - m);//把 f 数组的前 n 项都复制进 tmp 里面并把 tmp 高于 m 的地方全部清零
    NTT(ans, n, 1), NTT(tmp, n, 1);
    FOR(i, 0, n - 1)
        ans[i] = 1ll * (2ll - 1ll * ans[i] * tmp[i] % mod + mod) % mod * ans[i] % mod;//根据公式计算
    NTT(ans, n, 0);
    clr(ans + m, n - m);//高于 m 的舍弃
    return;
}

int f[maxn << 1], ans[maxn << 1];

int n, m;

int main()
{
    n = readInt();
    FOR(i, 0, n - 1) f[i] = readInt();
    invpoly(f, ans, n);
    FOR(i, 0, n - 1) printf("%d ", ans[i]);
    return 0;
}

当然我们也可以不递归,考虑递推实现(其实快不了多少的)。递推无非就是从 1 1 1 开始倍增向上走。

void invpoly(int *f, int m)
{
    int n;
    for (n = 1; n < m; n <<= 1);
    static int w[maxn << 1], r[maxn << 1], sav[maxn << 1];
    w[0] = qpow(f[0]);
    for (int len = 2; len <= n; len <<= 1)//len 代表当前的界
    {
        FOR(i, 0, (len >> 1) - 1)
            r[i] = (w[i] << 1) % mod;//处理 2G'(x)
        cpy(sav, f, len);
        NTT(w, len << 1, 1);
        FOR(i, 0, (len << 1) - 1)
            w[i] = 1ll * w[i] * w[i] % mod;
        NTT(sav, len << 1, 1);
        FOR(i, 0, (len << 1) - 1)
            w[i] = 1ll * w[i] * sav[i] % mod;
        NTT(w, len << 1, 0);
        clr(w + len, len);
        FOR(i, 0, len - 1)
            w[i] = (r[i] - w[i] + mod) % mod;
    }
    cpy(f, w, m);//把答案还回到 f 数组里面
    clr(sav, n << 1);//清零
    clr(w, n << 1);//清零
    clr(r, n << 1);//清零
    return;
}

请注意在递推进行乘法的时候一定要把空间开成 len << 1,为什么呢?计算 G ′ 2 ( x ) F ( x ) G'^2(x)F(x) G2(x)F(x) 的时候, G ′ 2 ( x ) G'^2(x) G2(x)len 项, F ( x ) F(x) F(x) 也含 len 项,卷起来就是 len << 1 项了,如果不开够的话 NTT 把点值还原回去的时候是会出问题的。

多项式的导数/积分

回顾一下一些基础的导数公式:
f ( x ) = e x    ⟹    f ′ ( x ) = e x f ( x ) = ln ⁡ x    ⟹    f ′ ( x ) = 1 x f ( x ) = a x k    ⟹    f ′ ( x ) = a k x k − 1 ( f ( x ) ± g ( x ) ) ′ = f ′ ( x ) ± g ′ ( x ) ( f ( g ( x ) ) ) ′ = f ′ ( g ( x ) ) × g ′ ( x ) \begin{aligned} f(x) = e^x&\implies f'(x) = e^x\\ f(x) = \ln x&\implies f'(x) = \frac1x\\ f(x) = ax^k&\implies f'(x) = akx^{k - 1}\\ (f(x)\pm g(x))'&= f'(x)\pm g'(x)\\ (f(g(x)))' &=f'(g(x))\times g'(x) \end{aligned} f(x)=exf(x)=lnxf(x)=axk(f(x)±g(x))(f(g(x)))f(x)=exf(x)=x1f(x)=akxk1=f(x)±g(x)=f(g(x))×g(x)
我们在这里定义一下多项式的求导:
f ( x ) = ∑ i = 0 n a i x i    ⟹    f ′ ( x ) = ∑ i = 0 n − 1 ( i + 1 ) a i + 1 x i f(x) = \sum_{i = 0}^na_ix^i\implies f'(x) = \sum_{i = 0}^{n - 1}(i + 1)a_{i + 1}x^i f(x)=i=0naixif(x)=i=0n1(i+1)ai+1xi
积分为求导的逆运算:
f ( x ) = ∑ i = 0 n a i x i    ⟹    ∫ f ( x ) d x = ∑ i = 1 n + 1 a i − 1 x i i f(x) = \sum_{i = 0}^na_ix_i\implies\int f(x) \mathrm dx= \sum_{i = 1}^{n + 1}\frac{a_{i - 1}x^i}{i} f(x)=i=0naixif(x)dx=i=1n+1iai1xi
所以多项式求导和求积分的代码就很容易写出来了,当然需要一开始线性预处理一下逆元。

void derivate(int *f, int m)
{
    FOR(i, 1, m - 1)
        f[i - 1] = 1ll * f[i] * i % mod;
    f[m - 1] = 0;
    return;
}

int inv[maxn];

void initinv(int lim)
{
    inv[1] = 1;
    FOR(i, 2, lim)
        inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
    return;
}

void intergrate(int *f, int m)
{
    DEC(i, m, 1)
        f[i] = 1ll * f[i - 1] * inv[i] % mod;
    f[0] = 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值