容易列出dp方程
f[i][j]
f
[
i
]
[
j
]
表示第i个位置乘积为
j
j
的方案数。
那么推出转移
此时复杂度
O(n∗m2)
O
(
n
∗
m
2
)
显然不能接受
想到对于
n
n
特别大的情况一般都会用快速幂。
那么可以列出一个m*m的矩阵进行转移,复杂度
明显还是不能接受,发现瓶颈在于m.
考虑利用fft把
m2
m
2
优化掉
如何优化,发现因为
j∗p%m
j
∗
p
%
m
导致没法直接递推
那么利用原根
g
g
的性质,求出的原根。
已知
gi
g
i
互不相同,令
i=g[i]
i
=
g
[
i
]
那么原方程可写作
f[i+1][gj∗gp%m]+=f[i][j]→f[i+1][gj+p%m]+=f[i][j]
f
[
i
+
1
]
[
g
j
∗
g
p
%
m
]
+
=
f
[
i
]
[
j
]
→
f
[
i
+
1
]
[
g
j
+
p
%
m
]
+
=
f
[
i
]
[
j
]
那么这样显然可以用fft优化
复杂度
O(logn∗logm∗m)
O
(
log
n
∗
log
m
∗
m
)
c++代码如下:
#include <bits/stdc++.h>
#define rep(i,x,y) for(register int i = x; i <= y; ++ i)
#define repd(i,x,y) for(register int i = x ; i >= y; -- i)
using namespace std;
typedef long long ll;
template<typename T>inline void read(T&x)
{
x = 0;char c;int sign = 1;
do { c = getchar(); if(c == '-') sign = -1; }while(!isdigit(c));
do { x = x * 10 + c - '0'; c = getchar(); }while(isdigit(c));
x *= sign;
}
const ll N = 2e4+50,mod = 1004535809,G = 3;
ll n,m,x,g,L,len,inv,S,s[N],ind[N];
ll R[N],a[N],b[N],c[N],d[N];
inline ll quick_pow(ll x,ll y,ll p)
{
ll ans = 1;
while(y)
{
if(y&1) ans = ans * x % p;
x = x * x % p;
y >>= 1;
}
return ans;
}
inline void get_g(ll m)
{
rep(i,1,m-1)
{
int j = 1;
while(j < m) { if(quick_pow(i,j,m) == 1) break; ++j; }
if(j == m - 1)
{
g = i;
break;
}
}
}
inline void ntt(ll*a,ll f)
{
rep(i,0,len-1) if(i < R[i]) swap(a[i],a[R[i]]);
for(register int i = 1 ;i < len; i <<= 1)
{
ll wn = quick_pow(G,(mod - 1)/(i << 1),mod);
if(f == -1) wn = quick_pow(wn,mod - 2,mod);
for(register int j = 0;j < len; j += i << 1)
{
ll w = 1;
for(register int k = 0;k < i; ++ k,w = w * wn % mod)
{
ll x = a[j + k],y = w * a[i + j + k] % mod;
a[j + k] = (x + y) % mod;
a[i + j + k] = ((x - y)%mod + mod) %mod;
}
}
}
if(f == -1)
{
rep(i,0,len-1) a[i] = a[i] * inv % mod;
}
}
inline void mul(ll*a,ll*b,ll m)
{
rep(i,0,len - 1) c[i] = a[i],d[i] = b[i];
ntt(c,1); ntt(d,1);
rep(i,0,len - 1) c[i] = c[i] * d[i] % mod,a[i] = 0;
ntt(c,-1);
rep(i,0,len - 1) a[i%m] = (a[i%m] + c[i]) % mod;
}
inline void solve()
{
inv = quick_pow(len,mod - 2,mod);
a[ind[1]] = 1;
rep(i,1,S) if(s[i]) b[ind[s[i]]] = 1;
while(n)
{
if(n&1) mul(a,b,m - 1);
mul(b,b,m - 1);
n >>= 1;
}
}
int main()
{
read(n); read(m); read(x); read(S);
rep(i,1,S) read(s[i]);
get_g(m);
rep(i,0,m - 2) ind[quick_pow(g,i,m)] = i;
for(len = 1; len <= m * 2; len <<= 1) ++ L;
rep(i,0,len - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1)<<(L - 1));
solve();
printf("%lld\n",a[ind[x]]);
return 0;
}