Description
小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数
列,数列中的每个数都属于集合S。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为
,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大
,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。
Input
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。 第二行,|S|个整数,表示集合S中的所有元素。
1<=N<=10^9,3<=M<=8000,M为质数 0<=x<=M-1,输入数据保证集合S中元素不重复x∈[1,m-1]
集合中的数∈[0,m-1]
Output
一行,一个整数,表示你求出的种类数mod 1004535809的值。
Sample Input
4 3 1 2
1 2
Sample Output
8
HINT
【样例说明】
可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、
(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)
题解
NTT入门题
首先求出m的原根g,那么1~m-1的所有数都可用g的次幂形式表示出
譬如 x=gi x = g i , y=gj y = g j ,对于 x∗y x ∗ y ,就相当于 gi+j g i + j 了
那么对于两数相乘等于次数相加的情况,可以想到我们可以对输入的S构造生成函数f
构造生成函数后求这个函数的n次幂,由于一个数可以反复取,所以无需容斥去重
求n次幂的时候,我们会发现,两个次数为m的多项式相乘时,次数会上涨至2*m。这时候我们可以暴力把后m~2*m这一段给加到前面来,因为是取模的嘛
取模一定要勤快呀说不定哪里就爆longlong了
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int MAXN=21000;
struct node
{
LL m[MAXN];
}A,ans;
const LL MOD=1004535809;
int R[21000],L,n,M,x,s;
LL pow_mod(LL a,int b,LL mod)
{
LL ans=1LL%mod;
while(b)
{
if(b&1)ans=ans*a%mod;
a=a*a%mod;b>>=1;
}
return ans;
}
void NTT(LL *y,int len,int op)
{
for(int i=0;i<len;i++)if(i<R[i])swap(y[i],y[R[i]]);
for(int i=1;i<len;i<<=1)
{
LL wn=pow_mod(3,(MOD-1)/(i*2),MOD);if(op==-1)wn=pow_mod(wn,MOD-2,MOD);
for(int j=0;j<len;j+=(i<<1))
{
LL w=1;
for(int k=0;k<i;k++)
{
LL u=y[j+k],v=w*y[j+k+i]%MOD;
y[j+k]=(u+v)%MOD;
y[j+k+i]=(u-v+MOD)%MOD;
w=w*wn%MOD;
}
}
}
if(op==-1)
{
LL inv=pow_mod(len,MOD-2,MOD);
for(int i=0;i<len;i++)y[i]=y[i]*inv%MOD;
}
}
LL b[MAXN],P[MAXN];
int len;
void mul(LL *ret,node x,node y,int len)
{
NTT(x.m,len,1);NTT(y.m,len,1);
for(int i=0;i<len;i++)x.m[i]=(x.m[i]*y.m[i])%MOD;
NTT(x.m,len,-1);
for(int i=0;i<len;i++)ret[i]=0;
for(int i=0;i<len;i++)ret[i%(M-1)]=(ret[i%(M-1)]+x.m[i])%MOD;
}
void sol()
{
int b=n;
while(b)
{
if(b&1)mul(ans.m,ans,A,len);
mul(A.m,A,A,len);b>>=1;
}
}
int get_root(int s)
{
int q[1100]={0};
for(int i=2;i<s-1;i++)
if((s-1)%i==0)q[++q[0]]=i;
for(int i=2;;i++)
{
bool bk=true;
for(int j=1;j<=q[0];j++)
{
if(pow_mod(i,q[j],s)==1)bk=false;
if(bk==false)break;
}
if(bk==true)return i;
}
return -1;
}
int main()
{
scanf("%d%d%d%d",&n,&M,&x,&s);
for(int i=1;i<=s;i++)scanf("%lld",&b[i]);
int tmp=1,g=get_root(M);
for(int i=0;i<M-1;i++)
{
P[tmp]=i;
tmp=tmp*g%M;
}
for(int i=1;i<=s;i++)if(b[i])A.m[P[b[i]]]=1;
L=0;
for(len=1;len<=M*2;len<<=1)L++;
for(int i=1;i<len;i++)R[i]=(R[i>>1]>>1)|(i&1)<<(L-1);
ans.m[0]=1;
sol();
printf("%lld\n",ans.m[P[x]]);
return 0;
}