题面
解法
- 先考虑一个最简单的dp:令 f [ i ] [ j ] f[i][j] f[i][j]表示前 i i i个数的乘积对 m m m取模为 j j j的方案数,转移比较简单,在这里就不写了。
- 但是我们会发现,转移的时候是乘法,并没有特别好的优化方式。
- 注意 m m m是一个质数,一定存在原根 g g g。那么,我们就可以用 g g g的若干次方表示出 [ 1 , m ) [1,m) [1,m)中的所有数。
- 现在我们不妨对原来的状态稍作修改, f [ i ] [ j ] f[i][j] f[i][j]表示为前 i i i个数的乘积对 m m m取模与 g j g^j gj同余,然后转移就是 f [ i ] [ j ] = ∑ k f [ i − 1 ] [ j − k ] × s u m [ k ] f[i][j]=\sum_{k}f[i-1][j-k]\times sum[k] f[i][j]=∑kf[i−1][j−k]×sum[k]
- 假设当前的 f [ i ] f[i] f[i]为 a a a, f [ i − 1 ] f[i-1] f[i−1]为 b b b,那么 a = b ∗ s u m a=b*sum a=b∗sum,所以最后的答案 f [ n ] = f [ 0 ] ∗ s u m n f[n]=f[0]*sum^n f[n]=f[0]∗sumn。
- 因为卷积满足结合律,所以我们可以对 s u m n sum^n sumn进行快速幂,在做乘法的用NTT加速。
- 时间复杂度: O ( m log n log m ) O(m\log n\log m) O(mlognlogm)
【注意事项】
- 原根可以暴力找出,并不会存在很大的原根。
- 在乘法的时候需要注意,中途的结果长度可能大于 m m m,那么对于 i ∈ [ m − 1 , 2 m ) i\in [m-1,2m) i∈[m−1,2m), a [ i − m + 1 ] + = a [ i ] a[i-m+1]+=a[i] a[i−m+1]+=a[i],因为 φ ( m ) = m − 1 \varphi(m)=m-1 φ(m)=m−1, g g g为原根,所以最小循环节为 m − 1 m-1 m−1,那么答案也要相应地加上去。
代码
#include <bits/stdc++.h>
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
const int N = 20010, Mod = 1004535809;
int m, a[N], b[N], c[N], f[N], g[N], num[N], rev[N];
bool check(int n, int g) {
for (int i = 1, cur = g; i < n - 1; i++, cur = cur * g % n)
if (cur == 1) return false;
return true;
}
int calc(int n) {for (int i = 2; ; i++) if (check(n, i)) return i;}
int Pow(int x, int y) {
int ret = 1;
for (; y; y >>= 1, x = 1ll * x * x % Mod)
if (y & 1) ret = 1ll * ret * x % Mod;
return ret;
}
void getrev(int l) {
for (int i = 0; i < (1 << l); i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
}
void NTT(int *a, int n, int fl) {
for (int i = 0; i < n; i++)
if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
int wn = Pow(3, fl == 1 ? (Mod - 1) / (i << 1) : Mod - 1 - (Mod - 1) / (i << 1));
for (int j = 0, r = i << 1; j < n; j += r) {
int w = 1;
for (int k = 0; k < i; k++, w = 1ll * w * wn % Mod) {
int tx = a[j + k], ty = 1ll * w * a[i + j + k] % Mod;
a[j + k] = (tx + ty) % Mod, a[i + j + k] = (tx - ty + Mod) % Mod;
}
}
}
if (fl == -1) {
int tmp = Pow(n, Mod - 2);
for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * tmp % Mod;
}
}
void mul(int *a, int *tx, int *ty, int n) {
for (int i = 0; i < n; i++) b[i] = tx[i], c[i] = ty[i];
NTT(b, n, 1), NTT(c, n, 1);
for (int i = 0; i < n; i++) a[i] = 1ll * b[i] * c[i] % Mod;
NTT(a, n, -1);
for (int i = m - 1; i < n; i++) a[i - m + 1] = (a[i - m + 1] + a[i]) % Mod, a[i] = 0;
}
int calc(int n, int len, int x) {
g[0] = 1;
while (n) {
if (n & 1) mul(g, g, f, len);
n >>= 1, mul(f, f, f, len);
}
return g[num[x]];
}
int main() {
int n, tx, s;
read(n), read(m), read(tx), read(s);
int t = calc(m);
for (int i = 1, cur = t; i < m - 1; i++, cur = cur * t % m) num[cur] = i;
for (int i = 1; i <= s; i++) {
int x; read(x);
if (x) f[num[x]]++;
}
int l = 0, len = 1;
while (len <= 2 * m) l++, len <<= 1; getrev(l);
cout << calc(n, len, tx) << "\n";
return 0;
}