这题我之前用的是背包,现在补上fft的做法。题意是给你n种商品,每种商品数量无限,有各自的价格,现在需要买k个商品,设所得到的总花费为w,问你所有可能的w的值。
思路:类似于母函数,共有k个多项式相乘,最后我们只需要观察x的幂次所对应的系数是否大于0.k比较大因此需要快速幂+fft.
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<stack>
#include<cmath>
const double PI=acos(-1.0);
using namespace std;
struct complex
{
double r,i;
complex(double _r = 0.0,double _i = 0.0)
{
r = _r; i = _i;
}
complex operator +(const complex &b)
{
return complex(r+b.r,i+b.i);
}
complex operator -(const complex &b)
{
return complex(r-b.r,i-b.i);
}
complex operator *(const complex &b)
{
return complex(r*b.r-i*b.i,r*b.i+i*b.r);
}
};
void change(complex y[],int len)// 将数组的长度调整成2的整数次幂
{
int i,j,k;
for(i = 1, j = len/2;i < len-1; i++)
{
if(i < j)swap(y[i],y[j]);
k = len/2;
while( j >= k)
{
j -= k;
k /= 2;
}
if(j < k) j += k;
}
}
void fft(complex y[],int len,int on)//傅里叶变化求出点值
{
change(y,len);
for(int h=2;h<=len;h<<=1)
{
complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h)
{
complex w(1,0);
for(int k=j;k<j+h/2;k++)
{
complex u=y[k];
complex t=w*y[k+h/2];//旋转因子
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1)
for(int i=0;i<len;i++)
y[i].r/=len;
}
const int maxn=2e6+5;
complex x1[maxn],x2[maxn];
int a[maxn],b[maxn];
void cal(int *a,int *b,int &lena,int &lenb)//快速实现多项式相乘,将结果保存在b数组中
{
int len=1;
while(len<lena+lenb)len<<=1;
for(int i=0;i<=lenb;i++)
x1[i]=complex(b[i],0);
for(int i=lenb+1;i<len;i++)
x1[i]=complex(0,0);
for(int i=0;i<=lena;i++)
x2[i]=complex(a[i],0);
for(int i=lena+1;i<len;i++)
x2[i]=complex(0,0);
fft(x1,len,1);
fft(x2,len,1);
for(int i=0;i<len;i++)
x1[i]=x1[i]*x2[i];//点值相乘
fft(x1,len,-1);//傅里叶逆变换求出系数
for(int i=0;i<=lena+lenb;i++)
b[i]=(int)(x1[i].r+0.5);
for(int i=0;i<=lena+lenb;i++)
if(b[i]>0)b[i]=1;//表示i这个价格可以达到
lenb+=lena;
}
int main()
{
//freopen("a.txt","r",stdin);
int n,k,x;
while(~scanf("%d%d",&n,&k))
{
memset(a,0,sizeof(a));
for(int i=0;i<n;i++)
{
scanf("%d",&x);
a[x]++;//a数组保存达到某种价格的方法数
}
b[0]=1;
int lena=1005,lenb=0;
while(k)//快速幂
{
if(k&1)cal(a,b,lena,lenb);//a,b数组快速幂更新,最终答案保存在b数组
if(k>1)cal(a,a,lena,lena);
k>>=1;
}
for(int i=0;i<=lena+lenb;i++)
{
if(b[i])printf("%d ",i);
}
printf("\n");
}
return 0;
}