题意
题解
如果直接考虑DP
f[i][j]
f
[
i
]
[
j
]
表示当前构造了
i
i
位,余数有多少种方案
然后再构造
g[i][j]
g
[
i
]
[
j
]
表示构造了至少
i
i
位,这个的话当一个累加器就好了
但是这样太慢了,是的
于是要考虑优化
我们考虑到,这种东西应该是可以合并的
于是我们就考虑使用一个类似快速幂的方式来求出这两个东西
如果是奇数的话,那么就先变为偶数,然后剩下一个暴力合并
偶数的话,就拆成
i/2
i
/
2
然后一个数
o=basei2
o
=
b
a
s
e
i
2
然后容易得到
f[j∗o+k]=∑f[j]∗f[k]
f
[
j
∗
o
+
k
]
=
∑
f
[
j
]
∗
f
[
k
]
然后这个显然是一个卷积的形式
你只需要把第一个
f[j]
f
[
j
]
放在
f[j∗o]
f
[
j
∗
o
]
就可以了
然后
g
g
的转移就是
然后操作是一样的
然后就可以了
时间复杂度
O(nlognlogp)
O
(
n
l
o
g
n
l
o
g
p
)
因为是求稳先写暴力的,所以可能代码比较长。。
但是你们也可以参考一下暴力
一开始因为pow的模数不同没有意识到而调了很久
CODE:
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const LL MOD=998244353;
const LL gi=3,ggi=332748118;
const LL N=50005;
LL base,p,x;
LL pow (LL x,LL y,LL p)
{
if (y==1) return x;
LL lalal=pow(x,y>>1,p);
lalal=lalal*lalal%p; //printf("%I64d %I64d\n",x,y);
if (y&1) lalal=lalal*x%p;
return lalal;
}
LL F[N],G[N];
LL f[N],g[N];
LL bin[N];
void ntt (LL *a,LL n,LL o)
{
for (LL u=0;u<n;u++) bin[u]=((bin[u>>1]>>1)|((u&1)*(n>>1)));
for (LL u=0;u<n;u++) if (u<bin[u]) swap(a[u],a[bin[u]]);
for (int u=1;u<n;u<<=1)
{
LL wn=pow(o==1?gi:ggi,(MOD-1)/(u<<1),MOD),w,t;
for (int i=0;i<n;i=i+(u<<1))
{
w=1;
for (int k=0;k<u;k++)
{
t=w*a[u+i+k]%MOD;
a[u+i+k]=(a[i+k]-t+MOD)%MOD;
a[i+k]=(a[i+k]+t)%MOD;
w=w*wn%MOD;
}
}
}
if (o==-1)
{
LL Inv=pow(n,MOD-2,MOD);
for (int u=0;u<n;u++) a[u]=a[u]*Inv%MOD;
}
}
LL a[N],b[N],now;
void solve (LL n)
{
if (n==1)
{
for (LL u='a';u<='z';u++) {f[u%p]++;g[u%p]++;}
return ;
}
if (n&1)
{
solve(n-1);
for (LL u=0;u<p;u++) {F[u]=f[u];f[u]=0;}
for (LL u='a';u<='z';u++)//枚举加上一个什么
for (LL i=0;i<p;i++)
{
LL h=(i*base+u)%p;
f[h]+=F[i];
if (f[h]>=MOD) f[h]-=MOD;
}
for (LL u=0;u<p;u++)
{
g[u]=g[u]+f[u];
if (g[u]>=MOD) g[u]-=MOD;
}
}
else
{
solve(n/2);
/*printf("%I64d\n",n);
printf("f:");for (LL u=0;u<p;u++) printf("%I64d ",f[u]);
printf("\n");
printf("g:");for (LL u=0;u<p;u++) printf("%I64d ",g[u]);
printf("\n");
system("pause");*/
LL o=pow(base,n/2,p);
for (LL u=0;u<p;u++) {F[u]=f[u];G[u]=g[u];f[u]=0;}
for (LL u=0;u<p;u++) a[u]=b[u]=0;
for (LL u=0;u<p;u++) a[(u*o)%p]+=F[u];
for (LL u=0;u<p;u++) b[u]=F[u];
for (LL u=p;u<now;u++) a[u]=b[u]=0;
/*for (LL u=0;u<p;u++)
for (LL i=0;i<p;i++)
f[(u+i)%p]=(f[(u+i)%p]+a[u]*b[i])%MOD;*/
ntt(a,now,1);ntt(b,now,1);
for (int u=0;u<now;u++) a[u]=a[u]*b[u]%MOD;
ntt(a,now,-1);
for (int u=0;u<now;u++) f[u%p]=(f[u%p]+a[u])%MOD;
/*for (LL u=0;u<p;u++)
for (LL i=0;i<p;i++)
{
LL h=(u*o+i)%p;
f[h]=(f[h]+F[u]*F[i]%MOD)%MOD;
}*/
for (LL u=0;u<p;u++) a[u]=b[u]=0;
for (LL u=0;u<p;u++) a[(u*o)%p]+=G[u];
for (LL u=0;u<p;u++) b[u]=F[u];
for (LL u=p;u<now;u++) a[u]=b[u]=0;
/*for (LL u=0;u<p;u++)
for (LL i=0;i<p;i++)
f[(u+i)%p]=(f[(u+i)%p]+a[u]*b[i])%MOD;*/
ntt(a,now,1);ntt(b,now,1);
for (int u=0;u<now;u++) a[u]=a[u]*b[u]%MOD;
ntt(a,now,-1);
for (int u=0;u<now;u++) g[u%p]=(g[u%p]+a[u])%MOD;
/*for (LL u=0;u<p;u++)
for (LL i=0;i<p;i++)
{
LL h=(u*o+i)%p;
g[h]=(g[h]+G[u]*F[i]%MOD)%MOD;
}*/
}
/* printf("%I64d\n",n);
printf("f:");for (LL u=0;u<p;u++) printf("%I64d ",f[u]);
printf("\n");
printf("g:");for (LL u=0;u<p;u++) printf("%I64d ",g[u]);
printf("\n");
system("pause");*/
}
int main()
{
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
LL n;
scanf("%I64d%I64d%I64d%I64d",&n,&base,&p,&x);
now=1;while (now<p) now<<=1;now<<=1;
solve(n);
/*for (int u=0;u<p;u++) printf("%I64d ",f[u]);
printf("\n");*/
//for (int u=0;u<p;u++) printf("%I64d ",g[u]);
printf("%I64d\n",g[x]);
return 0;
}