我们写出斧头的生成函数
F(x)
题目要求用1把、2把、3把斧头能拼出的方案数,不考虑顺序
那就要去掉非法情况和重复情况
所以就不能写成:
F(x)+F2(x)+F3(x)
对于
F2(x)
,他会有一把斧头用2次的情况
对于
F3(x)
,会有一把斧头用2、3次的情况
于是令
T(x)
为每把斧头用2次的生成函数,
G(x)
为每把斧头用3次的生成函数
答案函数为
F(x)1!+F2(x)−T(x)2!+F3(x)−3F(x)T(x)+2G(x)3!
FFT加速卷积
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<vector>
#include<string>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
#define inf 1e15
using namespace std;
const int maxn = 310000;
const double pi = acos(-1);
int n,N,lg,u;
struct E
{
double x,y;
E(){}
E(const double _x,const double _y){x=_x;y=_y;}
}s1[maxn],s2[maxn],s3[maxn],w[maxn],a2[maxn],a3[maxn]; int id[maxn];
inline E operator +(const E &x,const E &y){return E(x.x+y.x,x.y+y.y);}
inline E operator -(const E &x,const E &y){return E(x.x-y.x,x.y-y.y);}
inline E operator *(const E &x,const E &y){return E(x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x);}
void pre()
{
u*=3;
lg=0,N=1; for(;N<=u;N<<=1,lg++);
for(int i=0;i<N;i++) id[i]=(id[i>>1]>>1)|((i&1)<<lg-1);
for(int m=2;m<=N;m<<=1)
{
int t=m>>1,tt=N/m;
for(int i=0;i<t;i++)
{
w[i*tt]=E(cos(2*pi*i/m),sin(2*pi*i/m));
w[N-i*tt]=E(cos(2*pi*i/m),sin(-2*pi*i/m));
}
}
}
void DFT(E *s,const int sig)
{
for(int i=0;i<N;i++) if(i<id[i])
swap(s[i],s[id[i]]);
for(int m=2;m<=N;m<<=1)
{
int t=m>>1,tt=N/m;
for(int i=0;i<t;i++)
{
E wn=sig==1?w[i*tt]:w[N-i*tt];
for(int j=i;j<N;j+=m)
{
E tx=s[j],ty=s[j+t]*wn;
s[j]=tx+ty;
s[j+t]=tx-ty;
}
}
}
if(sig==-1) for(int i=0;i<N;i++) s[i].x/=(double)N;
}
int ans[maxn];
int main()
{
scanf("%d",&n); u=0;
for(int i=1;i<=n;i++)
{
int x; scanf("%d",&x); ans[x]++;
if(x>u) u=x;
s1[x].x+=1.0;
s2[x*2].x+=1.0;
s3[x*3].x+=1.0;
}
pre();
DFT(s1,1); DFT(s2,1); DFT(s3,1);
for(int i=0;i<N;i++)
{
a2[i]=s1[i]*s1[i]; a2[i]=a2[i]-s2[i];
E t1=s1[i]*s1[i]; t1=t1*s1[i];
E t2=s1[i]*s2[i]; t2.x*=3.0,t2.y*=3.0;
E t3=s3[i]; t3.x*=2.0,t3.y*=2.0;
a3[i]=t1-t2; a3[i]=a3[i]+t3;
}
DFT(a2,-1); DFT(a3,-1);
for(int i=0;i<N;i++)
{
ans[i]+=(int)((a2[i].x+0.5)/2.0)+(int)((a3[i].x+0.5)/6.0);
if(ans[i]) printf("%d %d\n",i,ans[i]);
}
return 0;
}