题意:
存在一个集合S,求长度为N,每一个元素都是S中的元素(可重复),并且该序列所有数的乘积mod M = x 的序列个数。
M是质数,且集合中的所有元素的范围都在[0,M-1]内。
并且x!=0
解析:
因为有M是质数这个特殊条件,所以我们可以求出来M的原根G,之后因为G的0~(phi(M)-1)可以完美替代0~M-1中的数,于是我们可以考虑把S中所有的数用G的几次幂来代替。
至于为什么这样考虑。
因为这样就把我们所需要的乘法转化成了幂的加法。
搞出集合S的生成函数。
由于每个数可以选取多次,所以接下来的问题就是S的生成函数的n次幂的对应的第x次幂项。
我们发现过程其实就是多项式的乘积过程,并且题目要求答案mod 一个原根为3的大质数,所以我们可以考虑用NTT来优化这一过程。
需要注意的是,在多项式乘积的时候,我们每一次要把大于m的系数加到其mod m后的那一项上,也就是说,不要直接消除,而是在乘积的时候把越界的部分转到mod m下。
总复杂度O(lognmlogm)
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define mod 1004535809
#define G 3
#define N 262145
using namespace std;
typedef long long ll;
int n,m,x,s,root;
ll prime[20010];
int pos[17010];
ll a[N],b[N];
int rev[N];
int num[17010];
int tot;
ll mm;
ll quick_my(ll x,ll y,ll MOD)
{
ll ret=1;
while(y)
{
if(y&1)ret=ret*x%MOD;
x=x*x%MOD;
y>>=1;
}
return ret;
}
void get_factor(ll x)
{
tot=0;
for(ll i=2;i*i<=x;i++)
{
if(x%i==0)
{
prime[++tot]=i;
while(x%i==0)x/=i;
}
}
if(x!=1)prime[++tot]=x;
}
bool check(ll x,ll MOD,ll PHI)
{
for(int i=1;i<=tot;i++)
{
if(quick_my(x,PHI/prime[i],MOD)==1)return 0;
}
return 1;
}
int find_primitive_root(ll x)
{
ll tmp=x-1;
get_factor(tmp);
for(int i=2;;i++)
{
if(check(i,x,tmp))return i;
}
}
void NTT(ll *a,int f)
{
for(int i=0;i<m;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int h=2;h<=m;h<<=1)
{
ll wn=quick_my(G,(mod-1)/h,mod);
for(int i=0;i<m;i+=h)
{
ll w=1;
for(int j=0;j<(h>>1);j++,w=w*wn%mod)
{
ll t=w*a[i+j+(h>>1)]%mod;
a[i+j+(h>>1)]=((a[i+j]-t)%mod+mod)%mod;
a[i+j]=(a[i+j]+t)%mod;
}
}
}
if(f==-1)
{
for(int i=1;i<(m>>1);i++)swap(a[i],a[m-i]);
ll inv=quick_my(m,mod-2,mod);
for(int i=0;i<m;i++)a[i]=a[i]*inv%mod;
}
}
ll ret[N];
void get_my(int y)
{
ret[0]=1;
while(y)
{
NTT(b,1);
if(y&1)
{
NTT(ret,1);
for(int i=0;i<m;i++)ret[i]=ret[i]*b[i]%mod;
NTT(ret,-1);
for(int i=m-1;i>=mm-1;i--)ret[i-mm+1]=(ret[i-mm+1]+ret[i])%mod,ret[i]=0;
}
for(int i=0;i<m;i++)b[i]=b[i]*b[i]%mod;
NTT(b,-1);
for(int i=m-1;i>=mm-1;i--)b[i-mm+1]=(b[i-mm+1]+b[i])%mod,b[i]=0;
y>>=1;
}
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&s);
for(int i=1;i<=s;i++)scanf("%d",&num[i]);
root=find_primitive_root(m);
ll tmp=1;
for(int i=0;i<m-1;i++)
pos[tmp]=i,tmp=tmp*root%m;
int l=m*2,L=0;
mm=m;
for(m=1;m<=l;m<<=1)L++;
for(int i=0;i<m;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
for(int i=1;i<=s;i++)
if(num[i]!=0)b[pos[num[i]]]++;
get_my(n);
printf("%lld\n",ret[pos[x]]);
}