题解:
观察下这个等式,其实就是
$$f_n = \sum_{i} f_{n-a_i}$$
(约定 $f_0 = 1$,下标为负的 $f$ 视为 $0$。)
前面的项用多项式求逆预处理:$F(x) = 1 / \left(1 - \sum_i x^{a_i}\right)$ 给出 $f_0, \dots, f_k$;后面的项用特征多项式,倍增求 $x^m$ 对特征多项式取模即可。时间复杂度 $O(k \log k \log m)$,其中 $k$ 为参与递推的 $a_i$ 的上界(代码中的 `lim` = 23333)。
注意 $a_1$ 可能大于 23333,直接拿它当下标会越界,需要跳过,RE 了半天。
#include <bits/stdc++.h>
using namespace std;
const int RLEN = (1 << 18) | 1;  // size of the stdin read buffer
// Returns the next byte of stdin. The RLEN-byte buffer is refilled with fread
// whenever it runs dry; once fread yields no data, -1 (converted to char) is
// returned, and every later call retries the read.
inline char nc() {
    static char buf[RLEN];
    static char *head, *tail;  // unread bytes are [head, tail)
    if (head == tail) {
        head = buf;
        tail = buf + fread(buf, 1, RLEN, stdin);
    }
    if (head == tail) return -1;
    return *head++;
}
// Reads the next decimal integer (an optional '-' before the digits makes it
// negative) through the buffered reader nc(). Returns 0 when the input ends
// before any digit is seen.
inline int rd() {
    const char eof = static_cast<char>(-1);  // what nc() yields once stdin is exhausted
    char ch = nc();
    int i = 0, f = 1;
    // Skip separators. Stop at end of input instead of spinning forever:
    // nc() keeps returning -1 after the last byte.
    // isdigit must receive an unsigned char value, never a negative char.
    while (ch != eof && !isdigit(static_cast<unsigned char>(ch))) {
        if (ch == '-') f = -1;
        ch = nc();
    }
    while (isdigit(static_cast<unsigned char>(ch))) {
        i = (i << 1) + (i << 3) + ch - '0';  // i * 10 + digit
        ch = nc();
    }
    return i * f;
}
const int N = 2e6 + 50;      // capacity bound used for global arrays
const int mod = 104857601;   // prime modulus, 25 * 2^22 + 1
const int lim = 23333;       // bound on the a_i that enter the recurrence

// (x + y) mod `mod`; the result is in [0, mod) when x, y are.
inline int add(int x, int y) {
    int s = x + y;
    return s < mod ? s : s - mod;
}

// (x - y) mod `mod`; the result is in [0, mod) when x, y are.
inline int dec(int x, int y) {
    int d = x - y;
    if (d < 0) d += mod;
    return d;
}

// x * y mod `mod`, using a 64-bit intermediate product.
inline int mul(int x, int y) {
    return static_cast<int>(1LL * x * y % mod);
}

// rs * a^b mod `mod` by square-and-multiply (rs defaults to 1).
inline int power(int a, int b, int rs = 1) {
    while (b) {
        if (b & 1) rs = mul(rs, a);
        b >>= 1;
        a = mul(a, a);
    }
    return rs;
}
namespace FFT {
// Number-theoretic transform over Z_mod. mod = 104857601 = 25 * 2^22 + 1 and
// 3 is used as the root generator, so the longest transform this modulus
// supports is 2^22; every table is sized for exactly that.
const int MAXK = 1 << 22;
// w   : twiddle table; the level with half-length bl occupies w[bl-1 .. 2bl-2]
//       and holds successive powers of 3^((mod-1)/(2*bl)).
// pos : bit-reversal permutation for the current length k.
// k   : current transform length (a power of two), set by init().
int w[MAXK], pos[MAXK], k;
// Operand buffers: fill A and B after init(); func() leaves A*B in B.
int A[MAXK], B[MAXK];
// Twiddle levels exist for every half-length below `ready` (a power of two).
int ready = 1;

// Builds any missing twiddle levels so that transforms of length len can run.
inline void ensure_roots(int len) {
    int bl = ready;
    for (; bl < len; bl <<= 1) {
        const int wn = power(3, (mod - 1) / bl / 2);
        int *lvl = w + bl - 1;
        lvl[0] = 1;
        for (int j = 1; j < bl; j++) lvl[j] = mul(lvl[j - 1], wn);
    }
    ready = bl;
}

// Eagerly builds the twiddles for lengths up to 2^18 (the same table the
// original fixed bound produced). Longer lengths are built on demand by init().
inline void pre() { ensure_roots(1 << 18); }

// Prepares a transform of length k = smallest power of two > nn. It builds the
// bit-reversal table, makes sure the twiddles exist, and zeroes A[0..k) and
// B[0..k). Aborts if that length exceeds MAXK, rather than silently producing
// garbage.
inline void init(int nn) {
    for (k = 1; k <= nn && k <= MAXK; k <<= 1);
    if (k > MAXK) {
        std::fputs("FFT::init: transform length exceeds 2^22\n", stderr);
        std::abort();
    }
    ensure_roots(k);
    for (int i = 1; i < k; i++)
        pos[i] = (i & 1) ? ((pos[i >> 1] >> 1) ^ (k >> 1)) : (pos[i >> 1] >> 1);
    std::memset(A, 0, sizeof(int) * k);
    std::memset(B, 0, sizeof(int) * k);
}

// In-place forward NTT of a[0..k): bit-reversal permutation, then iterative
// butterflies. Values stay in [0, mod) when the inputs do.
inline void dft(int *a) {
    for (int i = 1; i < k; i++) if (pos[i] > i) std::swap(a[pos[i]], a[i]);
    for (int bl = 1; bl < k; bl <<= 1) {
        const int *wl = w + bl - 1;  // twiddles for this level
        for (int bg = 0; bg < k; bg += bl << 1)
            for (int j = 0; j < bl; j++) {
                int &t1 = a[bg + j], &t2 = a[bg + j + bl], t = mul(t2, wl[j]);
                t2 = dec(t1, t);
                t1 = add(t1, t);
            }
    }
}

// Cyclic convolution of length k: B <- A * B (A is overwritten by its
// transform). The inverse transform is a forward transform followed by
// reversing B[1..k) and scaling by 1/k.
inline void func() {
    dft(A); dft(B);
    for (int i = 0; i < k; i++) B[i] = mul(A[i], B[i]);
    dft(B); std::reverse(B + 1, B + k);
    const int inv = power(k, mod - 2);
    for (int i = 0; i < k; i++) B[i] = mul(B[i], inv);
}
}
// Global problem input; m is declared 64-bit, a[] holds indices below N.
int n;
int a[N];
int A, B;
long long m;
// Dense polynomial over Z_mod: a[i] is the coefficient of x^i. The modular
// helpers (add/dec/mul) keep results in [0, mod) only for inputs in [0, mod),
// so coefficients are expected to stay in that range.
struct poly {
    std::vector<int> a;
    // Stored degree (size - 1); trailing zero coefficients are not trimmed.
    inline int deg() const {return a.size()-1;}
    // The polynomial t * x^d: d+1 coefficients, all zero except a[d] = t.
    poly(int d=0,int t=0) {a.resize(d+1); a[d]=t;}
    inline int& operator [](const int &b) {return a[b];}
    inline const int& operator [](const int &b) const {return a[b];}
    // Product via NTT; the result has degree deg(a) + deg(b).
    friend inline poly operator *(const poly &a,const poly &b) {
        FFT::init(a.deg()+b.deg());
        for(int i=0;i<=a.deg();i++) FFT::A[i]=a[i];
        for(int i=0;i<=b.deg();i++) FFT::B[i]=b[i];
        FFT::func();
        poly c(a.deg()+b.deg(),0);
        for(int i=0;i<=c.deg();i++) c[i]=FFT::B[i];
        return c;
    }
    // Multiplies every coefficient by the scalar b.
    friend inline poly operator *(const poly &a,const int &b) {
        poly c=a;
        for(int i=0;i<=c.deg();i++) c.a[i]=mul(c.a[i],b);
        return c;
    }
    // Coefficient-wise a - b; the result has degree max(deg a, deg b).
    friend inline poly operator -(const poly &a,const poly &b) {
        poly c(std::max(a.deg(),b.deg()));
        for(int i=0;i<=a.deg();i++) c[i]=a[i];
        for(int i=0;i<=b.deg();i++) c[i]=dec(c[i],b[i]);
        return c;
    }
    // Coefficient-wise a + b; the result has degree max(deg a, deg b).
    friend inline poly operator +(const poly &a,const poly &b) {
        poly c(std::max(a.deg(),b.deg()));
        for(int i=0;i<=a.deg();i++) c[i]=a[i];
        for(int i=0;i<=b.deg();i++) c[i]=add(c[i],b[i]);
        return c;
    }
    // Newton iteration for the inverse of f modulo x^len (len a power of two):
    // g <- 2g - g^2 f, doubling the precision each step. Requires f[0] != 0.
    // Returns a polynomial of degree len-1.
    static poly calc_inv(const poly &f,int len) {
        if(len==1) return poly(0,power(f[0],mod-2));
        poly f0=calc_inv(f.extend(len>>1),len>>1);
        return (f0*2-(f0*f0).extend(len-1)*f).extend(len-1);
    }
    // Inverse of this polynomial modulo x^(nn+1), returned with degree nn.
    inline poly calc_inverse(int nn) const {
        int len=1;
        while(len<=nn) len<<=1;  // smallest power of two > nn
        return calc_inv(extend(nn),len).extend(nn);
    }
    // Coefficients in reverse order (x^deg * p(1/x)).
    inline poly rev() const {poly c=*this; std::reverse(c.a.begin(),c.a.end()); return c;}
    // Remainder of a modulo b (b nonzero, leading coefficient nonzero), via the
    // reversed-division identity: q = rev(rev(a) * rev(b)^{-1} mod x^{res+1}),
    // r = a - b*q, where res = deg a - deg b. The result has degree deg(b)-1,
    // or is a unchanged when deg a < deg b.
    friend inline poly operator %(const poly &a,const poly &b) {
        // A remainder by a nonzero constant is 0. Handling this here avoids
        // returning a degree -1 (empty) polynomial that later operations would
        // index out of range.
        if(b.deg()==0) return poly(0,0);
        if(a.deg()<b.deg()) return a;
        int res=a.deg()-b.deg();
        poly ra=a.rev().extend(res);
        poly rb=b.rev().extend(res);
        poly q=(ra*rb.calc_inverse(res)).extend(res).rev();
        return (a-b*q).extend(b.deg()-1);
    }
    // Replaces the polynomial by its derivative in place (a constant becomes 0).
    inline void dg() {
        if(!deg()) a[0]=0;
        else {
            for(int i=0;i<deg();i++) a[i]=mul(i+1,a[i+1]);
            a.pop_back();
        }
    }
    // Copy truncated or zero-padded to exactly nn+1 coefficients.
    inline poly extend(int nn) const {
        poly c=*this;
        c.a.resize(nn+1);
        return c;
    }
    // Evaluates the polynomial at x (x expected in [0, mod)).
    inline int getval(int x) const {
        int res=0, w=1;
        for(int i=0;i<=deg();i++)
            res=add(res,mul(w,a[i])), w=mul(w,x);
        return res;
    }
};
int main() {
FFT::pre();
cin>>n>>m;
cin>>a[1]>>A>>B;
for(int i=2;i<=n;i++) a[i]=((long long)a[i-1]*A+B)%lim+1;
poly g(lim,0);
for(int i=1;i<=n;i++) if(a[i]<=lim) g[a[i]]++;
poly f(0,1);
f=(f-g).calc_inverse(lim);
g=poly(lim,1);
for(int i=1;i<=n;i++) if(a[i]<=lim) g[lim-a[i]]--;
poly rs(0,1), h(1,1);
for(;m;m>>=1,h=h*h%g)
if(m&1) rs=rs*h%g;
int ans=0;
for(int i=0;i<=lim && i<=rs.deg();i++)
ans=add(ans,mul(rs[i],f[i]));
cout<<ans<<'\n';
}