最近终于能干自己想干的事情了
问题引入:
快速求多项式乘法(卷积)
引理:
n个不同点的点集可以确定唯一 的n次多项式
记相乘的两个多项式为
A
,
B
A,B
A,B
A
,
B
A,B
A,B的卷积记为
C
C
C
solution1
O ( n 2 ) O(n ^ 2) O(n2)的多项式乘法
solution2
根据引理,可以任意取 A , B A,B A,B上 ∣ A ∣ + ∣ B ∣ |A| + |B| ∣A∣+∣B∣个点
然后把 其纵坐标相乘,故在通过这 ∣ A ∣ + ∣ B ∣ |A| + |B| ∣A∣+∣B∣个点来求出C
复杂度 O ( n 3 ) [ 高 斯 消 去 ] O(n ^ 3)[高斯消去] O(n3)[高斯消去] / O ( n 2 ) [ 拉 格 朗 日 插 值 ] O(n ^ 2)[拉格朗日插值] O(n2)[拉格朗日插值]/…
solution3
瓶颈:通过点值表达式来求出 C C C
优化:每次可以选择一些横坐标不同,可以用速算点加速
复杂度 O ( n l o g n ) O(nlogn) O(nlogn)
实现具体如下:
重要概念
令 W ( n , i ) W(n , i) W(n,i)为将单位圆平分成n分的第i个点 ( 0 < = i < n ) (0 <= i < n) (0<=i<n)
W ( 2 ∗ n , 2 * i ) = W ( n , i ) W(2 *n , 2*i) = W(n , i) W(2∗n,2*i) = W(n,i)
W ( n , i + j ) = W ( n , i ) ∗ W ( n , j ) W(n , i + j) = W(n , i) * W(n , j) W(n,i+j) = W(n,i)∗W(n,j)
W ( n , i + n / 2 ) = − W ( n , i ) W(n , i + n/2) = -W(n , i) W(n,i+n/2)=−W(n,i)
重要公式
e i x = c o s x + i s i n x e^{ix} = cosx + isinx eix=cosx+isinx
联系 W ( n , i ) W(n , i) W(n,i),可以得到:
W ( n , i ) = c o s ( 2 ∗ P I ∗ i / n ) + i s i n ( 2 ∗ P I ∗ i / n ) W(n , i) = cos(2 * PI * i / n) + isin(2 * PI * i / n) W(n,i)=cos(2∗PI∗i/n)+isin(2∗PI∗i/n)
正戏开始
A ( w n k ) = ∑ i = 0 n − 1 a i ∗ W n k + i ) A(w ^ k_n) = \sum_{i = 0}^{n - 1}a_i*W^{k + i}_n) A(wnk)=∑i=0n−1ai∗Wnk+i)
B ( w n k ) = ∑ i = 0 n − 1 2 a i ∗ 2 ∗ W n k ∗ i ∗ 2 ) B(w ^ k_n) = \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2}*W^{k *i * 2}_n) B(wnk)=∑i=02n−1 ai∗2∗Wnk∗i∗2)
C ( w n k ) = ∑ i = 0 n − 1 2 a i ∗ 2 + 1 ∗ W n k ( i ∗ 2 + 1 ) ) = C(w ^ k_n) = \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2 + 1}*W^{k (i * 2 + 1)}_n)= C(wnk)=∑i=02n−1 ai∗2+1∗Wnk(i∗2+1))=
W n k ∗ ∑ i = 0 n − 1 2 a i ∗ 2 + 1 ∗ W n k ∗ i ∗ 2 ) W^{k}_n * \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2 + 1}*W^{k * i * 2}_n) Wnk∗∑i=02n−1 ai∗2+1∗Wnk∗i∗2)
A ( w n k ) = B ( w n k ) + C ( W n k ) A(w^k_n) = B(w^k_n) + C(W^k_n) A(wnk)=B(wnk)+C(Wnk)
同理
A ( w n k + n 2 ) = ∑ i = 0 n − 1 a i ∗ W n k + i ) A(w ^ {k + \frac{n}{2}}_n) = \sum_{i = 0}^{n - 1}a_i*W^{k + i}_n) A(wnk+2n)=∑i=0n−1ai∗Wnk+i)
B ( w n k + n 2 ) = ∑ i = 0 n − 1 2 a i ∗ 2 ∗ W n ( k + n 2 ) ∗ i ∗ 2 ) B(w ^ {k + \frac{n}{2}}_n) = \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2}*W^{(k + \frac{n}{2}) *i * 2}_n) B(wnk+2n)=∑i=02n−1 ai∗2∗Wn(k+2n)∗i∗2)
C ( w n k + n 2 ) = ∑ i = 0 n − 1 2 a i ∗ 2 + 1 ∗ W n k ( k + n 2 ) ( i ∗ 2 + 1 ) ) = C(w ^ {k + \frac{n}{2}}_n) = \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2 + 1}*W^{k(k + \frac{n}{2}) (i * 2 + 1)}_n)= C(wnk+2n)=∑i=02n−1 ai∗2+1∗Wnk(k+2n)(i∗2+1))=
( − 1 ) ∗ W n k ∗ ∑ i = 0 n − 1 2 a i ∗ 2 + 1 ∗ W n ( k + n 2 ) + i ∗ 2 ) (-1)* W^{k}_n * \sum_{i = 0}^{ \frac{n-1}{2} \ }a_{i * 2 + 1}*W^{(k + \frac{n}{2}) + i * 2}_n) (−1)∗Wnk∗∑i=02n−1 ai∗2+1∗Wn(k+2n)+i∗2)
综上
A ( x ) = B ( x ) + C ( x ) A(x) = B(x) + C(x) A(x)=B(x)+C(x)
A ( w n k ) = B ( w n 2 k ) + C ( w n 2 k ) A(w^k_n) = B(w^{2k}_n) + C(w^{2k}_n) A(wnk)=B(wn2k)+C(wn2k)
A ( w n k + n 2 ) = B ( w n k + n 2 ) + C ( w n k + n 2 ) A(w ^ {k + \frac{n}{2}}_n) = B(w ^ {k + \frac{n}{2}}_n) + C(w ^ {k + \frac{n}{2}}_n) A(wnk+2n)=B(wnk+2n)+C(wnk+2n)
人话:
就是根据 w n i w_n^i wni的循环性质,一次求出每个 x x x取 w n i w^i_n wni的时候整个多项式的值,求出点值表达式,就在 O ( n l o g n ) O(nlogn) O(nlogn)的时间解决了这个问题
点值表达式转回原表达式
直接每次把 W n i W_n^i Wni取倒,而后做一次fft,且每一项都 ∗ n − 1 *n^{-1} ∗n−1
具体如下
#include<bits/stdc++.h>
#define MAXN 4000005
using namespace std;
int n,m,len,limit,rev[MAXN];
const double PI = acos(-1.0);
struct node{double shi,xu;}a[MAXN],b[MAXN];
node operator + (node x , node y){return (node){x.shi + y.shi , x.xu + y.xu};}
node operator - (node x , node y){return (node){x.shi - y.shi , x.xu - y.xu};}
node operator * (node x , node y){return (node){x.shi * y.shi - x.xu * y.xu , x.shi * y.xu + x.xu * y.shi};}
node dw(double x , double y , int h){//W(x , y)
double zz = y / x;
zz = 2.0 * zz * PI;
return (node){cos(zz) , sin(zz) * h};
}
void fft(node f[] , int tp){
for(int i = 0 ; i < len ; i++)if(i < rev[i])swap(f[i] , f[rev[i]]);
for(int i = 1 ; i < len ; i = i * 2){
for(int j = 0 ; j < len ; j = j + (i * 2)){
node xs = dw(i * 2 , 1 , tp) , W = (node){1 , 0};
for(int k = 0 ; k < i ; k++ , W = W * xs){
node p = f[k + j] , pp = f[k + j + i] * W;
f[k + j] = p + pp;
f[k + j + i] = p - pp;
}
}
}
}
int main(){
cin>>n>>m , n++ , m++;
for(int i = 0 ; i < n ; i++)cin>>a[i].shi;
for(int i = 0 ; i < m ; i++)cin>>b[i].shi;
while((1 << limit) <= n + m)limit++;
len = (1 << limit);
for(int i = 0 ; i < len ; i++){
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (limit - 1)));
}
fft(a , 1) , fft(b , 1);
for(int i = 0 ; i < len ; i++)a[i] = a[i] * b[i];
fft(a , -1);
for(int i = 0 ; i <= (n + m - 2) ; i++)cout<<int((a[i].shi / (len * 1.0) + 0.5))<<" ";
}