题目描述
int work(char *s,int n,int base,int p)
{
long long ans=0;
for(int i=1;i<=n;i++)
ans=(ans*base+s[i])%p;
return ans;
}
以上代码为一种字符串hash的写法,给出base和p,试统计长度小于等于n且能使最后hash值为x的字符串(只能包含小写字母)有多少个。
n,p,base≤50000 0≤x<p
分析
hash值可以看成一个base的次数界为n的多项式。
首先考虑最暴力的dp。设f[i][j]表示确定字符串最后i位,hash值模p等于j的方案数。那么枚举倒数第i+1位的字符c,然后f[i][j]转移到
f[i+1][(j+ci)%p]
现在要考虑优化这个dp。假设现在要算f[i][],那么可不可以用f[
⌊i2⌋
][]算出f[i][]呢?
其实是可以的,假设i是偶数,
f[i][(j+ki2)%p]=∑f[i2][j]∗f[i2][k]
然后把
f[i2][k]
存到
a[ki2]
里,把模去掉,就是个卷积的形式了。可以用NTT解决。
当i是奇数,可以先算出f[i-1][],然后暴力算出f[i][]。
但是问题求的是长度小于等于n的字符串的答案。那么可以设
g[i][j]
表示长度小于等于i的答案,j的意义一样。然后用g[i/2][]和f[i/2][]用NTT来求出g[i][]。
时间复杂度
O(nlogplogn)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=131085,mo=998244353;
typedef long long LL;
int n,m,x,base,M,t[N],f[N],g[N],h[N],W[N],r,T,Inv,Ans[N],ss[N];
int quick(int x,int t,int mo)
{
if (!t) return 1;
int tmp=quick(x,t>>1,mo);
tmp=(LL)tmp*tmp%mo;
if (t&1) tmp=(LL)tmp*x%mo;
return tmp;
}
void DFT(int *a,int sig)
{
for (int i=0;i<M;i++)
{
int pos=0;
for (int tmp=i,j=0;j<T;j++,tmp/=2) pos=pos*2+(tmp&1);
t[pos]=a[i];
}
for (int l=2;l<=M;l*=2)
{
int half=l>>1,tmp=M/l;
for (int i=0;i<half;i++)
{
int w=(sig==1)?W[i*tmp]:W[M-i*tmp];
for (int k=i;k<M;k+=l)
{
int p=t[k],q=(LL)t[k+half]*w%mo;
t[k]=(p+q)%mo; t[k+half]=(p-q)%mo;
}
}
}
for (int i=0;i<M;i++) a[i]=t[i];
}
void solve(int x)
{
if (x==1) return;
solve(x>>1);
memset(g,0,sizeof(g));
memset(ss,0,sizeof(ss));
int p=quick(base,x>>1,m);
for (int i=0;i<m;i++) g[(LL)i*p%m]=(g[(LL)i*p%m]+f[i])%mo,ss[(LL)i*p%m]=(ss[(LL)i*p%m]+Ans[i])%mo;
DFT(f,1);
DFT(ss,1);
DFT(g,1);
for (int i=0;i<M;i++) ss[i]=(LL)ss[i]*f[i]%mo,f[i]=(LL)f[i]*g[i]%mo;
DFT(f,-1);
DFT(ss,-1);
for (int i=0;i<M;i++) t[i]=(LL)f[i]*Inv%mo;
memset(f,0,sizeof(f));
for (int i=0;i<M;i++) f[i%m]=(f[i%m]+t[i])%mo;
for (int i=0;i<M;i++) t[i]=(LL)ss[i]*Inv%mo;
for (int i=0;i<M;i++) Ans[i%m]=(Ans[i%m]+t[i])%mo;
if (x&1)
{
memset(g,0,sizeof(g));
for (int i=0;i<m;i++) g[(LL)i*base%m]=(g[(LL)i*base%m]+f[i])%mo;
DFT(g,1);
for (int i=0;i<M;i++) f[i]=(LL)g[i]*h[i]%mo;
DFT(f,-1);
for (int i=0;i<M;i++) t[i]=(LL)f[i]*Inv%mo;
memset(f,0,sizeof(f));
for (int i=0;i<M;i++) f[i%m]=(f[i%m]+t[i])%mo;
for (int i=0;i<m;i++) Ans[i]=(Ans[i]+f[i])%mo;
}
}
int main()
{
scanf("%d%d%d%d",&n,&base,&m,&x);
for (int i='a';i<='z';i++) h[i%m]++;
for (M=1;M<m*2;M=M<<1);
T=log(M)/log(2);
Inv=quick(M,mo-2,mo);
W[0]=1; W[1]=quick(3,(mo-1)/M,mo);
for (int i=2;i<=M;i++) W[i]=(LL)W[i-1]*W[1]%mo;
memcpy(f,h,sizeof(f));
memcpy(Ans,h,sizeof(h));
DFT(h,1);
solve(n);
Ans[x]=(Ans[x]+mo)%mo;
printf("%d\n",Ans[x]);
return 0;
}