序列统计
Description
小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。
Input
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。
第二行,|S|个整数,表示集合S中的所有元素。
1<=N<=10^9,3<=M<=8000,M为质数
0<=x<=M-1,输入数据保证集合S中元素不重复x∈[1,m-1]
集合中的数∈[0,m-1]
Output
一行,一个整数,表示你求出的种类数mod 1004535809的值。
Sample Input
4 3 1 2
1 2
Sample Output
8
【样例说明】
可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、
(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)
对原根的认识停留在基本性质和能用在NTT上两方面……
窝太菜了QAQ
思路:
首先可以得到一个暴力dp,设
f[i][j]
代表数列长度为
i
,当前乘积模
然后考虑优化这个暴力。
然而这个乘法很难优化。
于是考虑化乘法为加法。
由于
m
为质数,考虑使用原根。
根据原根的性质,对于质数
于是,使用原根的一个次幂
gi
去表示每个
ai
,则
ai∗aj≡gi+j
。
于是就成功地将乘法化为了加法!
令
gbi=ai
,
gy=x
,那么转移方程可以变成这样:
答案即为 f[n][y] 。
可以发现转移是一个循环卷积的形式。
于是上多项式快速幂即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
inline ll read()
{
ll x=0;char ch=getchar();
while(ch<'0' || '9'<ch)ch=getchar();
while('0'<=ch && ch<='9')x=x*10+(ch^48),ch=getchar();
return x;
}
const int N=32009;
const int md=1004535809;
ll n,m,x,s,g,l;
ll a[N],ha[N],f[N],c[N],h[N],rev[N];
inline ll qpow(ll a,ll b,ll p=md)
{
ll ret=1;
while(b)
{
if(b&1)ret=ret*a%p;
a=a*a%p;b>>=1;
}
return ret;
}
inline ll calc(ll x)
{
for(int i=2;i;i++)
{
for(int j=2;j*j<x;j++)
if(qpow(i,(x-1)/j,x)==1)
goto nxt;
return i;
nxt:;
}
}
inline void init(int n)
{
for(int i=0;i<n;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
}
inline void NTT(ll *a,int n,int f)
{
for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int h=2;h<=n;h<<=1)
{
ll w=qpow(3,(md-1)/h);
if(f)w=qpow(w,md-2);
for(int j=0;j<n;j+=h)
{
ll wn=1ll;
for(int k=j;k<j+(h>>1);k++)
{
ll x=a[k],y=a[k+(h>>1)]*wn%md;
a[k]=(x+y)%md;
a[k+(h>>1)]=(x-y+md)%md;
wn=wn*w%md;
}
}
}
if(f)
for(ll i=0,inv=qpow(n,md-2);i<n;i++)
a[i]=a[i]*inv%md;
}
inline void mul(ll *a,ll *b,ll *c)
{
static ll d[N];
memset(d,0,sizeof(d));
for(int i=0;i<m;i++)
d[i]=a[i];
NTT(d,l,0);NTT(b,l,0);
for(int i=0;i<l;i++)
d[i]=d[i]*b[i]%md;
NTT(d,l,1);NTT(b,l,1);
for(int i=0;i<m;i++)
c[i]=d[i];
for(int i=m;i<l;i++)
(c[i%m]+=d[i])%=md;
}
int main()
{
n=read();m=read();
x=read();s=read();
for(int i=1;i<=s;i++)
a[i]=read();
ll tmp=(g=calc(m));
ha[0]=-1;ha[1]=0;ha[tmp]=1;
for(int i=2;i<m-1;i++)
ha[tmp=tmp*g%m]=i;
for(int i=1;i<=s;i++)
a[i]=ha[a[i]];
x=ha[x];m--;
for(l=1;l<(m<<1);l<<=1);
for(int i=1;i<=s;i++)
if(~a[i])
f[a[i]]++;
c[0]=1;init(l);
while(n)
{
if(n&1)mul(c,f,c);
mul(f,f,f);n>>=1;
}
printf("%lld\n",c[x]);
return 0;
}