题目描述及题解
题解就到这位大佬的博客上看吧。。说得很清楚。。
然而菜爆了的我还是调了3个小时。。。所以来说说实现细节。。
- nlogn预处理单位根 ω \omega ω,递推乘的话精度会炸
- mod是1e9+7,常规的FFT会炸long long,需要拆系数:
多项式A和B相乘,把 a i , b i a_i,b_i ai,bi拆成 k M + p kM+p kM+p的形式( M = m o d M=\sqrt{mod} M=mod)
再分别做 k [ a i ] 与 k [ b i ] , k [ a i ] 与 p [ b i ] , p [ a i ] 与 k [ b i ] , p [ a i ] 与 p [ b i ] k[a_i]与k[b_i],k[a_i]与p[b_i],p[a_i]与k[b_i],p[a_i]与p[b_i] k[ai]与k[bi],k[ai]与p[bi],p[ai]与k[bi],p[ai]与p[bi]的卷积,然后合并答案 - static 是静态数组,函数里的static a=1在第二次调用时不会重新赋值。。。
#include<cstdio>
#include<cmath>
#include<algorithm>
#define LL long long
#define maxn 120005
using namespace std;
const double Pi = acos(-1);
const int mod = 1e9+7, M1 = 31623, M2 = 14122;
struct complex
{
double r,i;
complex(double _r=0,double _i=0):r(_r),i(_i){}
complex operator + (const complex &t)const{return complex(r+t.r,i+t.i);}
complex operator - (const complex &t)const{return complex(r-t.r,i-t.i);}
complex operator * (const complex &t)const{return complex(r*t.r-i*t.i,r*t.i+i*t.r);}
complex conj(){return complex(r,-i);}
}w[16][maxn/2];
void change(complex *a,int len)
{
for(int i=1,j=len/2,k;i<len-1;i++)
{
if(i<j) swap(a[i],a[j]);
for(k=len/2;j>=k;j-=k,k>>=1);
j+=k;
}
}
int m,len,f[maxn],g[maxn],p2[maxn];
LL n,fac[maxn],inv[maxn];
inline void fft(complex *a,int flg)
{
change(a,len);
for(int i=2,o=0;i<=len;i<<=1,o++)
for(int j=0;j<len;j+=i)
for(int k=j;k<j+i/2;k++)
{
complex u=a[k],v=(flg==1?w[o][k-j]:w[o][k-j].conj())*a[k+i/2];
a[k]=u+v,a[k+i/2]=u-v;
}
if(flg==-1) for(int i=0;i<len;i++) a[i].r/=len;
}
void calc(int *A,int *B,int *ret)
{
static complex sta[2][2][maxn];
for(int i=0;i<len;i++)
if(i<=m)
{
sta[0][0][i]=A[i]/M1,sta[0][1][i]=A[i]%M1;
sta[1][0][i]=B[i]/M1,sta[1][1][i]=B[i]%M1;
}
else sta[0][0][i]=sta[0][1][i]=sta[1][0][i]=sta[1][1][i]=0;
fft(sta[0][0],1),fft(sta[0][1],1),fft(sta[1][0],1),fft(sta[1][1],1);
static complex rt[3][maxn];
for(int i=0;i<len;i++)
{
rt[0][i]=sta[0][1][i]*sta[1][1][i];
rt[1][i]=sta[0][0][i]*sta[1][1][i]+sta[0][1][i]*sta[1][0][i];
rt[2][i]=sta[0][0][i]*sta[1][0][i];
}
fft(rt[0],-1),fft(rt[1],-1),fft(rt[2],-1);
for(int i=0;i<len;i++) ret[i]=(llround(rt[0][i].r)%mod+llround(rt[1][i].r)%mod*M1%mod+llround(rt[2][i].r)*M2%mod)%mod;
}
void solve(int *A,int *B,int cnt,int *ret)
{
static int tmp[2][maxn];
int sp2=1;
for(int i=0;i<len;i++,sp2=1ll*sp2*p2[cnt]%mod)
tmp[0][i]=1ll*A[i]*sp2%mod,tmp[1][i]=B[i];
calc(tmp[0],tmp[1],ret);
}
void FAC_INV(int N)
{
fac[0]=fac[1]=inv[0]=inv[1]=p2[0]=1,p2[1]=2;
for(int i=2;i<=N;i++) fac[i]=fac[i-1]*i%mod,inv[i]=(mod-mod/i)*inv[mod%i]%mod,p2[i]=p2[i-1]*2%mod;
for(int i=2;i<=N;i++) inv[i]=inv[i]*inv[i-1]%mod;
}
int main()
{
scanf("%lld%d",&n,&m);
if(n>m) return puts("0"),0;
FAC_INV(m);
len=1;while(len<2*m+1) len<<=1;
for(int i=2,k=0;i<=len;i<<=1,k++)
for(int j=0;j<i/2;j++) w[k][j]=complex(cos(2*Pi*j/i),sin(2*Pi*j/i));
for(int i=1;i<=m;i++) g[i]=inv[i];
f[0]=1;
int cnt=1,ans=0;
for(;n;n>>=1,solve(g,g,cnt,g),cnt<<=1) if(n&1) solve(f,g,cnt,f);
for(int i=1;i<=m;i++) ans=(ans+f[i]*fac[m]%mod*inv[m-i]%mod)%mod;
printf("%d",ans);
}