题目描述
传送门
题目大意:给定元素在[1,m)内的整数集合S,求有多少个长度为n的数列满足所有元素属于S且mod m下的积为x,元素可以重复出现。
题解
这道题看到之后容易想到dp
f[i][j]
表示的是选到第i个数,乘积%m的方案数
f[i][j]=∑i=1|s|f[i−1][j∗inv[S[i]]
其中
|s|
表示的是集合的元素个数。
乘法不利于后面的计算,由于
M
为质数,我们引进原根的概念。
设
假设一个数g对于P来说是原根,那么
简单来说, gimod p≠gjmod p ( p 为素数)
其中
求原根目前的做法只能是从2开始枚举,然后暴力判断
g(P−1)=1(modP)
是否当且仅当指数为
P−1
的时候成立
而由于原根一般都不大,所以可以暴力得到.
那么所有元素的积都可以转化成和形式,即
∏si∈seqsi≡x mod m
由于原根的幂可以表示
[1,m)
内的数
令
hi,hx
满足
ghi mod m=ei,ghx mod =x
∏ghi≡g∑hi≡x=ghx mod m
∑hi≡hx mod m−1
于是问题转化成了在
S
中可重复的选n个元素使元素和为指定值。
那么我们可以将上面的DP式子进行变形
我们用g[i]表示状态i的方案数,f[i]表示原根的i次幂是否是集合中的元素
g[i]=∑jg[j]∗f[i−j]
这是一个卷积的形式,那么我们可以用
NTT
来优化
那么什么是
NTT
呢?
NTT
是快速数论变换,如果取模的数是一个质数
P
,且
具体的实现过程于
FFT
基本相同,只需要将原根替换,并且注意及时取模即可。
可以参见代码
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 17000
#define LL long long
#define p 1004535809
using namespace std;
int n,m,x,s,vis[N],pr,R[N],M,L,X,cnt;
LL f[N],g[N],a[N],b[N];
LL quickpow(LL num,LL x,LL p1)
{
LL base=num%p1; LL ans=1;
while (x) {
if (x&1) ans=ans*base%p1;
x>>=1;
base=base*base%p1;
}
return ans%p;
}
int calc(int x)//求原根,当且仅当i=p-1是g^i=1 (mod p)
{
if (x==2) return 1;
for (int i=2;i;i++) {
bool pd=1;
for (int j=2;j*j<x;j++)
if (quickpow(i,(x-1)/j,x)==1) {
pd=false;
break;
}
if (pd) return i;
}
}
void NTT(LL x1[N],int n,int opt)
{
int j;
for (int i=0;i<n;i++)
if (i<R[i]) swap(x1[i],x1[R[i]]);
for (int i=1;i<n;i<<=1) {
LL wn=quickpow(3,(p-1)/(i<<1),p);
for (int p1=i<<1,j=0;j<n;j+=p1) {
LL w=1;
for (int k=0;k<i;k++,w=(w*wn)%p) {
LL x=x1[j+k],y=(w*x1[j+k+i])%p;
x1[j+k]=(x+y)%p; x1[j+k+i]=(x-y+p)%p;
}
}
}
if (opt==-1) reverse(x1+1,x1+n);
}
void mul(LL g[N],LL f[N])
{
for (int i=0;i<n;i++) a[i]=g[i]%p;
for (int i=0;i<n;i++) b[i]=f[i]%p;
NTT(a,n,1); NTT(b,n,1);
for (int i=0;i<n;i++) a[i]=(a[i]*b[i])%p;
NTT(a,n,-1);
LL inv=quickpow(n,p-2,p);
for (int i=0;i<n;i++) a[i]=(a[i]*inv)%p;
for (int i=0;i<m-1;i++)
g[i]=(a[i]+a[i+m-1])%p;
}
void solve(int x)
{
g[0]=1;
while (x) {
if (x&1) mul(g,f);
x>>=1;
mul(f,f);
}
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d%d%d",&cnt,&m,&X,&s);
pr=calc(m);
// cout<<pr<<endl;
for (int i=1;i<=s;i++) {
int x; scanf("%d",&x);
vis[x]=1;
}
int pos=-1;
for (int i=0,j=1;i<m-1;i++,j=(j*pr)%m) {
if (vis[j]) f[i]=1;
if (j==X) pos=i;
}
M=(m-1)*2;
for (n=1;n<=M;n<<=1) L++;
for (int i=0;i<n;i++)
R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
solve(cnt);
if (pos!=-1) printf("%I64d\n",g[pos]%p);
else printf("0\n");
}