链接:http://acm.hdu.edu.cn/showproblem.php?pid=4656
脑子抽了做叉姐200题,一道题上午续到晚上。。
这道题不仅式子贼难推(推了几个小时弃疗),推完还不能直接上fft,要用什么任意模数fft(代码又调几个小时)。。
part1:推式子
----转载自https://blog.csdn.net/whzzt/article/details/70880091(Orz)
个人认为比较高妙的两步是将枚举顺序交换,然后先处理掉一部分化简式子,还有一步就是(k-j)^2那个吊炸天的操作了(推了上午搞不出来就是因为c^2kj这个玄学东西)。
part2:任意模数fft
之后因为ai在10^6,而且模数不是ntt模数,所以要用什么任意模数fft,强行学了一波,其实是一种解决子问题在合并的思想,将模数拆到根号级别,得到任意一个数可表示为a*k+b形式,把两个要进行卷积的数组中每个数都拆成a,b,一共四个数组,都做一次DFT,然后发现实际上(a1*k+b1)*(a2*k+b2)==a1*a2*k*k+(a1*b2+a2*b1)*k+b1*b1,k是一个常数提到fft过程完成后乘,所以这样按乘k的次数分3组最后合并,就能保证fft时数的大小转变为根号n级,不会爆精度了。。
代码:
#include<bits/stdc++.h>
#define ll long long
#define db double
using namespace std;
const ll mod=1e6+3,mo=sqrt(mod);
const int N=5e5+10,M=2e5+10;
const db pi=acos(-1);
struct cp{
db r,i;
cp(){r=0,i=0;}
cp(db x,db y):r(x),i(y){}
}omg[N],mp[4][N];
cp operator +(const cp x,const cp y)
{return cp(x.r+y.r,x.i+y.i);}
cp operator -(const cp x,const cp y)
{return cp(x.r-y.r,x.i-y.i);}
cp operator *(const cp x,const cp y)
{return cp(x.r*y.r-x.i*y.i,x.r*y.i+x.i*y.r);}
int n,wh[N],lim,cnt,tp1[N],tp2[N],tax1[N],tax2[N],tax3[N];
int jc[M],inv[mod+10],ijc[M],bb[M],dd[M],idd[M];
int b,c,d,a[M];
ll p[N],ans[N];
ll qpow(ll x,ll y)
{
ll res=1;
while(y)
{
if(y&1)res=1LL*res*x%mod;
x=1LL*x*x%mod,y>>=1;
}
return res;
}
void init()
{
jc[0]=1,ijc[0]=1,inv[1]=1;
for(int i=2;i<mod;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
for(int i=1;i<=n;i++)
jc[i]=1LL*jc[i-1]*i%mod,ijc[i]=1LL*ijc[i-1]*inv[i]%mod;
bb[0]=dd[0]=idd[0]=1;
for(int i=1;i<=n;i++)
{
bb[i]=1LL*bb[i-1]*b%mod;
dd[i]=1LL*dd[i-1]*d%mod;
idd[i]=1LL*idd[i-1]*inv[d]%mod;
}
for(int i=1;i<=lim;i++)
wh[i]=(wh[i>>1]>>1)|((i&1)<<(cnt-1));
for(int i=0;i<lim;i++)
omg[i]=cp(cos(2*pi*i/lim),sin(2*pi*i/lim));
omg[lim]=omg[0];
}
void fft(cp *a,bool inv)
{
int mid;cp t;
for(int i=0;i<lim;i++)
if(i<wh[i])swap(a[i],a[wh[i]]);
for(int l=2;l<=lim;l<<=1)
{
mid=l>>1;
for(int i=0;i<lim;i+=l)
{
for(int j=0;j<mid;j++)
{
t=a[i+j+mid]*(inv?omg[lim-lim/l*j]:omg[lim/l*j]);
a[i+j+mid]=a[i+j]-t;
a[i+j]=a[i+j]+t;
}
}
}
}
void mtt(int *x,int *y)
{
static cp t1,t2,t3;
for(int i=0;i<lim;i++)
{
mp[0][i]=cp(x[i]/mo,0),mp[1][i]=cp(x[i]%mo,0);
mp[2][i]=cp(y[i]/mo,0),mp[3][i]=cp(y[i]%mo,0);
}
fft(mp[0],0),fft(mp[1],0),fft(mp[2],0),fft(mp[3],0);
for(int i=0;i<lim;i++)
{
t1=mp[0][i]*mp[2][i],t2=mp[1][i]*mp[2][i]+mp[3][i]*mp[0][i];
t3=mp[1][i]*mp[3][i],mp[0][i]=t1,mp[1][i]=t2,mp[2][i]=t3;
}
fft(mp[0],1),fft(mp[1],1),fft(mp[2],1);
for(int i=0;i<lim;i++)
{
mp[0][i].r/=lim,mp[1][i].r/=lim,mp[2][i].r/=lim;
x[i]=(((ll)(mp[0][i].r+0.5)%mod*mo*mo%mod)+((ll)(mp[1][i].r+0.5)%mod*mo%mod)+((ll)(mp[2][i].r+0.5)%mod))%mod;
}
}
main(void)
{
scanf("%d",&n);
scanf("%d%d%d",&b,&c,&d);b%=mod,c%=mod,d%=mod;
for(int i=0;i<n;i++)
scanf("%d",&a[i]),a[i]%=mod;
lim=1,cnt=0;
while(lim<=2*n)lim<<=1,cnt++;
init();
for(int i=0;i<n;i++)
{
tp1[i]=1LL*dd[i]*jc[i]*a[i]%mod;
tp2[i]=ijc[i];
}
for(int i=0;i<n;i++)
if(i<n-i)swap(tp1[i],tp1[n-i]);
mtt(tp1,tp2);
for(int i=0;i<n;i++)p[i]=tp1[n-i];
for(int i=0;i<=lim;i++)
tp1[i]=tp2[i]=0;
tp2[n]=1;
for(ll i=0;i<n;i++)
{
tp1[i]=1LL*bb[i]*qpow(c,i*i)*p[i]%mod*idd[i]*ijc[i]%mod;
tp2[n+i]=inv[qpow(c,i*i)],tp2[i]=inv[qpow(c,(n-i)*(n-i))];
}
mtt(tp1,tp2);
for(int i=0;i<n;i++)ans[i]=tp1[n+i],ans[i]=1LL*ans[i]*qpow(c,1LL*i*i)%mod;
for(int i=0;i<n;i++)
printf("%lld\n",ans[i]);
}
upd:
两个多项式A,B要做DFT,可设多项式P,Q,令Pi = Ai + i * Bi,Qi = Ai - i * Bi,则可以发现
Q(ωk)= conj(P(ω−k)) = conj(P(ωL−k))
只要对P做DFT就可以同时知道Q,然后P和Q相加减得到A,B做DFT结果。
IDFT逆运算,设M(ωk)=A(ωk)+i∗B(ωk), 对M做IDFT,然后取出实部虚部的值即是A,B的结果。
优化到4次fft的mtt:
#include<bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define sc second
#define pb push_back
#define ll long long
#define trav(v,x) for(auto v:x)
#define all(x) (x).begin(), (x).end()
#define VI vector<int>
#define VLL vector<ll>
#define db double
using namespace std;
const int N = 1e6 + 100;
const int M = 32767;
const db pi = acos(-1);
const ll mod = 998244352;
struct cp{
db r, i;
cp(double r = 0, double i = 0) : r(r), i(i){}
cp operator * (const cp &a) {return cp(r * a.r - i * a.i, r * a.i + i * a.r);}
cp operator + (const cp &a) {return cp(r + a.r, i + a.i);}
cp operator - (const cp &a) {return cp(r - a.r, i - a.i);}
}w[N], A[N], B[N], AA[N], BB[N];
int len, cc, wh[N];
cp conj(cp a)
{return cp(a.r, -a.i);}
void fft(cp *a, bool inv)
{
cp tmp;
for(int i = 0; i < len; i++)
if(i < wh[i])swap(a[i], a[wh[i]]);
for(int l = 2; l <= len; l <<= 1)
{
int mid = l >> 1;
for(int i = 0; i < len; i += l)
{
for(int j = 0; j < mid; j++)
{
tmp = a[i + j + mid] * (inv ? w[len - len / l * j] : w[len / l * j]);
a[i + j + mid] = a[i + j] - tmp;
a[i + j] = a[i + j] + tmp;
}
}
}
}
VI mul(VI &a, VI &b) // mtt with M = 32767
{
VI res;
len = 1, cc = 0;
while(len < a.size() + b.size())
len <<= 1, ++cc;
for(int i = 1; i <= len; i++)
wh[i] = (wh[i >> 1] >> 1) | ((i & 1) << (cc - 1));
for(int i = 0; i <= len; i++)
w[i] = cp(cos(2.0 * pi * i / len), sin(2.0 * pi * i / len));
int sz = a.size() + b.size() - 1;
a.resize(len), b.resize(len);
for(int i = 0; i < len; i++)
{
A[i] = cp(a[i] / M, a[i] % M);
B[i] = cp(b[i] / M, b[i] % M);
}
fft(A, 0), fft(B, 0);
for(int i = 0; i < len; i++)
{
cp aa, bb, cc, dd;
int j = (len - i) % len;
aa = (A[i] + conj(A[j])) * cp(0.5, 0);
bb = (A[i] - conj(A[j])) * cp(0, -0.5);
cc = (B[i] + conj(B[j])) * cp(0.5, 0);
dd = (B[i] - conj(B[j])) * cp(0, -0.5);
AA[i] = aa * cc + aa * dd * cp(0, 1);
BB[i] = bb * dd + bb * cc * cp(0, 1);
}
fft(AA, 1), fft(BB, 1);
res.resize(sz);
for(int i = 0; i < sz; i++)
{
ll ac, ad, bc, bd;
ac = (ll)(AA[i].r / len + 0.5) % mod;
ad = (ll)(AA[i].i / len + 0.5) % mod;
bd = (ll)(BB[i].r / len + 0.5) % mod;
bc = (ll)(BB[i].i / len + 0.5) % mod;
res[i] = (ac * M * M + (ad + bc) * M + bd) % mod;
}
return res;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
return 0;
}