题目传送门
停更已久的blog。。。
orz ckw大佬
我们考虑构造
f
f
f的递推式。我们设有数组
c
c
c,满足
f
[
n
]
=
∑
i
=
1
k
c
[
i
]
×
f
[
n
−
i
]
f[n]=\sum_{i=1}^{k}c[i]\times f[n-i]
f[n]=i=1∑kc[i]×f[n−i]
我们令
c
[
0
]
=
−
1
c[0]=-1
c[0]=−1,则
∑
i
=
0
k
c
[
i
]
×
f
[
n
−
i
]
=
0
\sum_{i=0}^{k}c[i]\times f[n-i]=0
i=0∑kc[i]×f[n−i]=0
带入一下
∑
i
=
0
k
c
[
i
]
×
∑
j
=
1
k
a
[
j
]
v
[
j
]
n
−
i
=
0
\sum_{i=0}^{k}c[i]\times \sum_{j=1}^{k}a[j]v[j]^{n-i}=0
i=0∑kc[i]×j=1∑ka[j]v[j]n−i=0
我们考虑去构造一个
c
c
c,满足这条式子。
我们发现如果任意
j
j
j都满足
∑
i
=
0
k
c
[
i
]
a
[
j
]
v
[
j
]
n
−
i
=
0
\sum_{i=0}^{k}c[i]a[j]v[j]^{n-i}=0
i=0∑kc[i]a[j]v[j]n−i=0则一定满足上面的式子。
再变一下
∑
i
=
0
k
c
[
i
]
v
[
j
]
n
−
i
=
0
\sum_{i=0}^{k}c[i]v[j]^{n-i}=0
i=0∑kc[i]v[j]n−i=0
不难发现仍然满足。
我们再设有多项式
F
(
x
)
=
∑
i
=
0
k
c
[
k
−
i
]
x
i
F(x)=\sum_{i=0}^{k}c[k-i]x^{i}
F(x)=i=0∑kc[k−i]xi
我们可以构造
F
(
x
)
=
−
∏
i
=
1
k
(
x
−
v
[
i
]
)
F(x)=-\prod_{i=1}^{k}(x-v[i])
F(x)=−i=1∏k(x−v[i])
不难发现,根据这个多项式得出的
c
c
c还是满足上面的式子的233。
于是就可以用分治FFT求出
F
(
x
)
F(x)
F(x)的各项系数,也就求出了
c
c
c数组。
接下来递推一下就好了。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=270005,mod=1004535809;
int n,k,len,v[N],f[N],res[N],rev[N],x[N],y[N];
int add(int a,int b){
a+=b;
return a<mod?a:a-mod;
}
int cut(int a,int b){
a-=b;
return a>=0?a:a+mod;
}
int fastpow(int a,int x){
int res=1;
while(x){
if(x&1){
res=1LL*res*a%mod;
}
x>>=1;
a=1LL*a*a%mod;
}
return res;
}
void ntt(int a[],int dft,int n){
for(int i=0;i<n;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
if(i<rev[i]){
swap(a[i],a[rev[i]]);
}
}
for(int i=1;i<n;i<<=1){
int wn=fastpow(3,(mod-1)/i/2);
if(dft==-1){
wn=fastpow(wn,mod-2);
}
for(int j=0;j<n;j+=(i<<1)){
int w=1,x,y;
for(int k=j;k<j+i;k++,w=1LL*w*wn%mod){
x=a[k];
y=1LL*w*a[k+i]%mod;
a[k]=add(x,y);
a[k+i]=cut(x,y);
}
}
}
if(dft==-1){
int inv=fastpow(n,mod-2);
for(int i=0;i<n;i++){
a[i]=1LL*a[i]*inv%mod;
}
}
}
void mul(int a[],int b[],int c[],int n,int m){
int len;
for(len=1;len<=n+m;len<<=1);
for(int i=0;i<n;i++){
x[i]=a[i];
}
for(int i=n;i<len;i++){
x[i]=0;
}
for(int i=0;i<m;i++){
y[i]=b[i];
}
for(int i=m;i<len;i++){
y[i]=0;
}
ntt(x,1,len);
ntt(y,1,len);
for(int i=0;i<len;i++){
x[i]=1LL*x[i]*y[i]%mod;
}
ntt(x,-1,len);
for(int i=0;i<n+m-1;i++){
c[i]=x[i];
}
}
void solve(int l,int r,int res[],int &len){
if(l==r){
res[0]=v[l]?mod-v[l]:0;
res[1]=1;
len=2;
return;
}
int mid=(l+r)/2;
int *lres=new int [(mid-l+1)<<1],*rres=new int [(r-mid)<<1],*llen=new int,*rlen=new int;
solve(l,mid,lres,*llen);
solve(mid+1,r,rres,*rlen);
mul(lres,rres,res,*llen,*rlen);
len=*llen+*rlen-1;
delete [] lres;
delete [] rres;
delete llen;
delete rlen;
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=k;i++){
scanf("%d",&v[i]);
}
for(int i=1;i<=k;i++){
scanf("%d",&f[i]);
}
solve(1,k,res,len);
for(int i=0;i<=k;i++){
res[i]=res[i]?mod-res[i]:0;
}
reverse(res,res+k+1);
for(int i=k+1;i<=n;i++){
for(int j=1;j<=k;j++){
f[i]=add(f[i],1LL*res[j]*f[i-j]%mod);
}
}
printf("%d\n",f[n]);
return 0;
}