首先说一下我用FFT做什么,我要做的是多项式乘法,或者说,加速多项式乘法。
考虑多项式
A(x)=∑j=0n−1ajxj
,它一共有
n
项,我们称它的次数界为
上面提到的多项式表示方法
A(x)=∑j=0n−1ajxj
称为系数表示,实际上它还有另一种表示方法叫点值表示。我们取
n
个不同的值
假设我们可以迅速在多项式的系数表示和点值表示间转换,就可以迅速完成多项式乘法。回到那个多项式
下面放出一道题:Thief in a Shop
可以这样理解题意:有
n
个数
于是,我们就需要FFT来迅速完成多项式乘法,同时利用倍增,只进行
log2(k)
次乘法。注意每次乘完之后,要对那个列向量规整一下,避免迭代过程累积误差。
下面的代码可以作为模板。。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const double eps = 0.5;
const double PI = acos(-1.0);
struct Complex{
double r,i;
Complex(double r=0.0,double i=0.0):r(r),i(i){
}
Complex operator+(const Complex& c)const{
return Complex(r+c.r,i+c.i);
}
Complex operator-(const Complex& c)const{
return Complex(r-c.r,i-c.i);
}
Complex operator*(const Complex& c)const{
return Complex(r*c.r-i*c.i,r*c.i+i*c.r);
}
};
void change(Complex y[],int len){
for(int i=1,j=len>>1;i<len-1;i++){
if(i<j){
swap(y[i],y[j]);
}
int k = len>>1;
while(j>=k){
j -= k;
k >>= 1;
}
if(j<k){
j += k;
}
}
}
void fft(Complex y[],int len,int on){
change(y,len);
for(int i=2;i<=len;i<<=1){
Complex wn(cos(-on*2*PI/i),sin(-on*2*PI/i));
for(int j=0;j<len;j+=i){
Complex w(1,0);
for(int k=j;k<j+i/2;k++){
Complex u = y[k];
Complex t = w*y[k+i/2];
y[k] = u + t;
y[k+i/2] = u - t;
w = w * wn;
}
}
}
if(on == -1){
for(int i=0;i<len;i++){
y[i].r /= len;
}
}
}
int lowbit(int x){
return x&(-x);
}
int fix(Complex *y,int l){
while(l && y[l-1].r<eps ){
l--;
}
for(int i=0;i<l;i++){
Complex &c = y[i];
if(c.r>eps){
c.r = 1;
}else{
c.r = 0;
}
c.i = 0;
}
for(int i=l;i<1024*1024;i++){
Complex &c = y[i];
c.r = c.i = 0;
}
return l+1;
}
void Print(Complex *y,int l){
for(int i=0;i<l;i++){
if(y[i].r>eps){
cout<<i<<" ";
}
}
cout<<endl;
}
int mul(Complex *v1,int l1,Complex *v2,int l2,Complex *res){
l1 = fix(v1,l1);
l2 = fix(v2,l2);
int sz = 2*max(l1,l2);
while(sz!=lowbit(sz)){
sz+=lowbit(sz);
}
l1 = l2 = sz;
fft(v1,l1,1);
fft(v2,l2,1);
for(int i=0;i<sz;i++){
res[i] = (v1[i]*v2[i]);
}
fft(res,sz,-1);
return sz;
}
Complex v[1024*1024];
Complex v2[1024*1024];
Complex ans[1024*1024];
int main(){
int n,k;
cin>>n>>k;
for(int i=1;i<=n;i++){
int num;
cin>>num;
v[num].r = 1.0;
}
int sz = 1024;
ans[0].r = 1;
while(k){
if(k&1){
for(int i=0;i<sz;i++){
v2[i] = v[i];
}
mul(v2,sz,ans,sz,ans);
}
for(int i=0;i<sz;i++){
v2[i] = v[i];
}
sz = mul(v,sz,v2,sz,v);
k>>=1;
}
sz = fix(ans,sz);
for(int i=0;i<sz;i++){
if(ans[i].r > eps){
printf("%d ",i);
}
}
return 0;
}