测试地址:序列统计
做法:本题需要用到NTT+循环卷积+快速幂。
这个题我们很快就想出状态转移:令
f(i,j)
f
(
i
,
j
)
为前
i
i
个数的乘积模的结果为
j
j
的数列方案数,那么有:
其中
[U]
[
U
]
表示表达式
U
U
成立时值为,否则值为
0
0
。
但是这个式子是的,无法承受。
由于运算是乘法,所以没办法使用NTT求卷积,那有没有什么办法把乘法变成加法呢?
我们在高中学过对数,对数满足
loga(xy)=logax+logay
l
o
g
a
(
x
y
)
=
l
o
g
a
x
+
l
o
g
a
y
,这就把乘法变成了加法,但是这是在实数域中,在模意义域中有没有类似的东西呢?
因为
m
m
是质数,所以一定存在一个原根,根据原根的性质,
g0,g1,...,gm−2
g
0
,
g
1
,
.
.
.
,
g
m
−
2
在模
m
m
意义下各不相同,因此我们类似的定义离散对数为使得
gx%m=y
g
x
%
m
=
y
的
x
x
,于是我们有:
于是我们用
I(j)
I
(
j
)
替代
f(i,j)
f
(
i
,
j
)
中的第二维下标
j
j
,并用替换
l
l
成为中的元素,称为
S0
S
0
,于是式子变为:
f(i,j)=∑0≤k<m−1∑l∈S0[(k+l)%(m−1)=j]f(i−1,k)
f
(
i
,
j
)
=
∑
0
≤
k
<
m
−
1
∑
l
∈
S
0
[
(
k
+
l
)
%
(
m
−
1
)
=
j
]
f
(
i
−
1
,
k
)
这样就把原来转移时候的模意义下的乘法变成了模意义下的加法。接下来我们定义向量
F(i)
F
(
i
)
为
f(i,0),...,f(i,m−2)
f
(
i
,
0
)
,
.
.
.
,
f
(
i
,
m
−
2
)
这一些数,我们发现
F(i)
F
(
i
)
就是
F(i−1)
F
(
i
−
1
)
和另一个向量
A
A
的一个循环卷积,其中,只用将卷积后下标
i
i
大于的值都累加在下标为
i%(m−2)
i
%
(
m
−
2
)
的位置上即可。用NTT优化求循环卷积的过程,时间复杂度降为
O(nmlogm)
O
(
n
m
log
m
)
。
然而还是不够,
n
n
达到了,意识到循环卷积运算满足交换律和结合律,用快速幂即可加速到
O(mlognlogm)
O
(
m
log
n
log
m
)
,至于原根可以直接
O(m2)
O
(
m
2
)
暴力求(实际上常数小的多),这样就解决了这道题。
有的同学可能注意到,上述方法不能处理
x=0
x
=
0
的情况,BZOJ上的题面说是有
x=0
x
=
0
,但洛谷上没有
x=0
x
=
0
,并且这份代码在两边都过了,所以推断数据应该不存在这种情况,所以无需特判。
以下是本人代码:
#include <bits/stdc++.h>
#define ll long long
#define mod 1004535809
#define g 3
using namespace std;
int n,m,x,s,p[8010],save,r[30010];
ll M[30010]={0},S[30010]={0};
bool vis[8010];
ll power(ll a,ll b)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=(s*ss)%mod;
ss=(ss*ss)%mod,b>>=1;
}
return s;
}
void NTT(ll *a,int n,int type)
{
for(int i=0;i<n;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1)
{
ll W=power(g,(mod-1)/(mid<<1));
if (type==-1) W=power(W,mod-2);
for(int l=0,r=mid<<1;l<n;l+=r)
{
ll w=1;
for(int k=0;k<mid;k++,w=(w*W)%mod)
{
ll x=a[l+k],y=(w*a[l+mid+k])%mod;
a[l+k]=(x+y)%mod;
a[l+mid+k]=((x-y)%mod+mod)%mod;
}
}
}
if (type==-1)
{
int inv=power(n,mod-2);
for(int i=0;i<n;i++)
a[i]=(a[i]*inv)%mod;
}
}
void power_conv(ll *a,int b,int n,ll *ans)
{
while(b)
{
NTT(a,n,1);
if (b&1)
{
NTT(ans,n,1);
for(int i=0;i<n;i++)
ans[i]=(ans[i]*a[i])%mod;
NTT(ans,n,-1);
for(int i=0;i<n;i++)
if (i>=save) ans[i%save]=(ans[i%save]+ans[i])%mod,ans[i]=0;
}
for(int i=0;i<n;i++)
a[i]=(a[i]*a[i])%mod;
NTT(a,n,-1);
for(int i=0;i<n;i++)
if (i>=save) a[i%save]=(a[i%save]+a[i])%mod,a[i]=0;
b>>=1;
}
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&s);
for(int i=2;i<=m;i++)
{
memset(vis,0,sizeof(vis));
bool flag=1;
for(int j=0,w=1;j<m-1;j++,w=(w*i)%m)
{
if (vis[w]) {flag=0;break;}
else p[w]=j,vis[w]=1;
}
if (flag) break;
}
for(int i=1;i<=s;i++)
{
int v;
scanf("%d",&v);
v%=m;
if (v) M[p[v]]++;
}
int bit=0,t=1;
while(t<(m<<1)) bit++,t<<=1;
r[0]=0;
for(int i=1;i<t;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
save=m-1,m=t;
S[0]=1;
power_conv(M,n,m,S);
printf("%lld",S[p[x]]);
return 0;
}