Description
小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。
Input
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。
Output
一行,一个整数,表示你求出的种类数mod 1004535809的值。
Sample Input
4 3 1 2
1 2
1 2
Sample Output
8
先暴力推了一波dp式子,看上去有点像卷积,可是第i项是和i/j项有关的。
由于m是质数,所以可以尝试用原根,因为原根g^i(0~p-2)在模p意义下两两不同并且不为0,所以对一个数x,若x=g^i mod (m - 1),那么就可以用i代替x,这样就可以把问题转换为在一个数集中选取n个数它们的和在模(p-1)意义下的值是x,这样dp中第i项就是和第i-j项有关的。
原根的求法,对于一个质数p,对p-1分解质因数p=p1^q1*p2^q2*...*pn^qn,然后我们从2开始枚举,如果i^((p-1)/pi)均不为1,那么i就是p的一个原根,因为原根不大所以这样的方法是没有问题的。
但是有一点要注意,这题中输入的数可能为0,这时候应该直接无视掉,原根无法处理这样的情况。
f[i][j]代表选i个数,数的和在模(m-1)意义下为j的方案数,那么f[i][j]=f[i-1][(j-num[i]+m-1) mod (m-1)]然后这就是个循环卷积,我们对长度为2m的多项式做多项式乘法,然后大于等于m-1的项累加到i mod (m-1)的项上,并清0。
可是那个n很大,我们发现我们只要计算出f[1],最后的多项式就是f[1]^n,快速幂就ok了
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MOD = 1004535809;
const int MAXN = 8005;
int n, m, i, j, k, num[MAXN], len, x, next[MAXN];
int a[16389], bit[16389], po[16389], b[16389], c[16389];
inline int get()
{
char c;
while ((c = getchar()) < 48 || c > 57);
int res = c - 48;
while ((c = getchar()) >= 48 && c <= 57)
res = res * 10 + c - 48;
return res;
}
inline int ksm(int x, int y, int z)
{
int b = 1;
while (y)
{
if (y & 1) b = (long long)b * x % z;
x = x * (long long)x % z;
y >>= 1;
}
return b;
}
inline int fuck(int x)
{
int y = x;
int cnt = 0;
for(int i = 2; i * i <= y; i ++)
if (x % i == 0)
{
a[++cnt] = i;
while (x % i == 0)
x /= i;
}
if (x != 1) a[++cnt] = x;
for(int i = 2; i <= y; i ++)
{
bool fp = 0;
for(int j = 1; j <= cnt; j ++)
if (ksm(i, y / a[j], y + 1) == 1)
{
fp = 1;
break;
}
if (!fp) return i;
}
}
inline int ntt_init(int m)
{
int n = 1, nn = 0;
while (n < m) n <<= 1, nn ++;
int g = ksm(3, (MOD - 1) / n, MOD);
po[0] = 1;
for(int i = 1; i <= n; i ++)
po[i] = po[i - 1] * (long long)g % MOD, bit[i] = (bit[i >> 1] >> 1) | ((i & 1) << nn - 1);
return nn;
}
inline void ntt(int *a, int nn, int ty)
{
int n = 1 << nn;
for(int i = 0; i < n; i ++)
if (i < bit[i]) swap(a[i], a[bit[i]]);
for(int k = 1; k <= nn; k ++)
{
int len = 1 << k, wn = (ty == 1) ? po[n / len] : po[n - n / len];
for(int j = 0; j < n; j += len)
{
int m = len >> 1, w = 1;
for(int i = j; i < j + m; i ++)
{
int l = a[i], t = a[i + m] * (long long)w % MOD;
a[i] = (l + t) % MOD;
a[i + m] = (l - t + MOD) % MOD;
w = w * (long long)wn % MOD;
}
}
}
}
int main()
{
n = get(); m = get(); x = get(); len = get();
int g = fuck(m - 1);
int w = 1;
for(i = 1; i <= m - 1; i ++)
{
next[w] = i - 1;
w = w * g % m;
}
x = next[x];
for(i = 1; i <= len; i ++)
num[i] = get(), num[i] = next[num[i]];
m --;
for(i = 1; i <= len; i ++)
if (num[i]) b[num[i]] = 1;
memset(a, 0, sizeof(a));
a[0] = 1;
int nn = ntt_init(m * 2);
int N = 1 << nn, inv = ksm(N , MOD - 2, MOD);
while (n)
{
if (n & 1)
{
for(i = 0; i <= N; i ++)
c[i] = b[i];
ntt(a, nn, 1);
ntt(c, nn, 1);
for(i = 0; i < N; i ++)
a[i] = (long long)a[i] * c[i] % MOD;
ntt(a, nn, -1);
for(i = 0; i < N; i ++)
a[i] = (long long)a[i] * inv % MOD;
for(i = m; i < N; i ++)
a[i - m] = (a[i - m] + a[i]) % MOD, a[i] = 0;
}
if (n == 1) break;
ntt(b, nn, 1);
for(i = 0; i < N; i ++)
b[i] = (long long)b[i] * b[i] % MOD;
ntt(b, nn, -1);
for(i = 0; i < N; i ++)
b[i] = (long long)b[i] * inv % MOD;
for(i = m; i < N; i ++)
b[i - m] = (b[i - m] + b[i]) % MOD, b[i] = 0;
n >>= 1;
}
cout << a[x];
}