题目大意: 给出一个集合S,元素都是小于m的整数。现给定整数x,求所有长度为n且每个元素都属于集合S的数列中满足数列中所有数的乘积mod m的值等于x的不同的数列的有多少
乘积为定值不怎么会算,而和为定值就是一个卷积的形式,于是将数列转成其对数做加法。
数列中有n个数,快速幂即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MOD 1004535809
#define N 17005
using namespace std;
typedef long long LL;
const int G=3,DFT=0,IDFT=1;
int n,m,X,S,len,tot,prime[N],pos[N],w[2][N],Log[N];
LL f_pow(LL x,LL y,LL mod) {
LL tmp=1;
while(y) {
if(y&1) (tmp*=x)%=mod;
(x*=x)%=mod;
y>>=1;
}
return tmp;
}
inline int get_inv(int x) { return f_pow(x,MOD-2,MOD); }
void init() {
len=1;
while(len<=m*2) len*=2;
for(int i=0;i<len;i++) {
pos[i]=pos[i/2]/2;
if(i&1) pos[i]|=len/2;
}
w[0][0]=w[0][len]=1;
int tmp=f_pow(G,(MOD-1)/len,MOD);
for(int i=1;i<len;i++) w[0][i]=(LL)w[0][i-1]*tmp%MOD;
for(int i=0;i<=len;i++) w[1][i]=w[0][len-i];
return ;
}
class Data {
private:
void NTT(int x[],int mode) {
for(int i=0;i<len;i++)
if(i<pos[i]) swap(x[i],x[pos[i]]);
for(int i=2;i<=len;i*=2) {
int step=i/2;
for(int j=0;j<len;j+=i) {
int limit=j+step;
for(int k=j;k<limit;k++) {
int l=x[k],r=(LL)x[k+step]*w[mode][len/i*(k-j)]%MOD;
x[k]=(l+r)%MOD, x[k+step]=(l-r+MOD)%MOD;
}
}
}
return ;
}
public:
int a[N];
Data() { memset(a,0,sizeof a); }
int& operator [] (const int& x) { return a[x]; }
Data& operator *= (const Data& rhs) {
static int b[N];
memcpy(b,rhs.a,sizeof b);
NTT(a,DFT), NTT(b,DFT);
for(int i=0;i<len;i++) a[i]=(LL)a[i]*b[i]%MOD;
NTT(a,IDFT);
for(int i=m-1;i<=m*2-4;i++)
(a[i-m+1]+=a[i])%=MOD, a[i]=0;
int inv=get_inv(len);
for(int i=0;i<=m-2;i++) a[i]=(LL)a[i]*inv%MOD;
return *this;
}
}a,ini;
Data f_pow(Data x,int y) {
Data tmp=ini;
while(y) {
if(y&1) tmp*=x;
x*=x;
y>>=1;
}
return tmp;
}
bool check(int x,int mod) {
static bool k[N];
for(int i=1;i<mod;i++) k[i]=false;
for(int i=1;i<mod;i++) k[f_pow(x,i,mod)]=true;
for(int i=1;i<mod;i++)
if(!k[i]) return false;
return true;
}
int get_g(int mod) {
int x=1;
while(!check(x,mod)) x++;
return x;
}
int main() {
scanf("%d%d%d%d",&n,&m,&X,&S);
init();
int g=get_g(m);
ini.a[0]=1;
for(int i=0,now=1;i<m-1;i++,(now*=g)%=m) Log[now]=i;
for(int i=1;i<=S;i++) {
int x;
scanf("%d",&x);
if(x) a[Log[x]]=1;
}
printf("%d\n",f_pow(a,n)[Log[X]]);
return 0;
}