本来不咋爱写博客的,但是这题实在让我受益匪浅,所以记录一下。 = =
题目大意是求所有长度为n的只含大写字母的字符串按所给函数进行hash,hash值有多少对相同,结果mod1e6+3。
思路:
刚开始想到用所给的hash函数进行dp,,即记录到第i个字符hash值为j的个数用26个字母递推到第i+1个字符对应的hash值,加上dp(i,j),是但是考虑了下复杂度是O(26*m*n),所以放弃了。
后来想是不是hash值分布均匀,于是想用26^n/m来弄,结果发现26^n并不能整除m,所以无奈pass。
在后来就是之后找到题解了,思想和开始的dp思想相似,但是将m种可能的hash值排成了多项式,变成多项式以后就可以用FFT(NNT+CRT)进行优化,每次推进的复杂度变成了mlogm,在此基础上,可以二分n来做FFT,将总复杂度缩短到了O(mlogmlogn),可喜可贺。
然后附上WA代码:
#include <vector>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
using namespace std;
const int maxn=1<<16;
const int mod=1e6+3;
const double pi = acos(-1.0);
struct complex
{
double a, b;
complex(double aa = 0.0, double bb = 0.0)
{
a = aa;
b = bb;
}
complex operator +(const complex &e)
{
return complex(a + e.a, b + e.b);
}
complex operator -(const complex &e)
{
return complex(a - e.a, b - e.b);
}
complex operator *(const complex &e)
{
return complex(a * e.a - b * e.b, a * e.b + b * e.a);
}
};
void change(complex * y, long long len)
{
long long i, j, k;
for (i = 1, j = len / 2; i < len - 1; i++)
{
if (i < j) swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(complex *y, long long len, long long on)
{
change(y, len);
for (int h = 2; h <= len; h <<= 1)
{
complex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
for (int j = 0; j < len; j += h)
{
complex w(1, 0);
for (int k = j; k < j + h / 2; k++)
{
complex u = y[k];
complex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1)
for (int i = 0; i < len; i++)
y[i].a /= len;
}
complex x[maxn],y[maxn];
void multiply(long long f[],long long g[],int m)
{
int len = 1;
while(len < 2*m)
len <<= 1;
for(int i=0;i<m;i++)
x[i]=complex(f[i],0);
for(int i=m;i<len;i++)
x[i]=complex(0,0);
for(int i=0;i<m;i++)
y[i]=complex(g[i],0);
for(int i=m;i<len;i++)
y[i]=complex(0,0);
fft(x,len,1);
fft(y,len,1);
for(int i=0;i<len;i++)
x[i]=x[i]*y[i];
fft(x,len,-1);
for(int i=0;i<m;i++)
f[i]=(long long)(x[i].a+0.5);
for(int i=m;i<len;i++)
{
f[i%m]=(f[i%m]+(long long)(x[i].a+0.5))%mod;
f[i]=0;
}
}
int main()
{
int n,m,p;
scanf("%d%d%d",&n,&m,&p);
long long ans[maxn],t[maxn],g[maxn];
memset(ans,0,sizeof(ans));
memset(t,0,sizeof(t));
memset(g,0,sizeof(g));
ans[0]=1;
for(int i='A'; i<='Z'; i++)
g[i%m]++;
while(n)
{
if(n&1)
{
int ck=0;
memset(t,0,sizeof(t));
for(int i=0; i<m; i++)
{
t[ck]=(t[ck]+ans[i])%mod;
ck=(ck+p)%m;
}
multiply(t,g,m);
memcpy(ans,t,sizeof(t));
}
int ck=0;
memset(t,0,sizeof(t));
for(int i=0; i<m; i++)
{
t[ck]=(t[ck]+g[i])%mod;
ck=(ck+p)%m;
}
multiply(t,g,m);
memcpy(g,t,sizeof(t));
p=(1ll*p*p)%m;
n>>=1;
}
long long ans1=0;
for(int i=0;i<m;i++)
{
if(ans[i]<2)
ans[i]+=mod;
ans1=(ans1+(1ll*ans[i]*(ans[i]-1)/2)%mod)%mod;
//cout<<i<<" "<<ans[i]<<endl;
}
printf("%I64d\n",ans1);
}
为什么会WA!!! 我想了好久各个地方应该都没问题啊,最后只能找dalao的代码观摩了下,发现他们的FFT模板精度(或者NTT)比我高,无奈只能找了一个厉害点的模板在FFT了一次,这次终于A了- -。
#include <vector>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
using namespace std;
typedef long double db;
const int MAXN=1<<16;
const int Mod=1e6+3;
const double PI=acos(-1.0);
struct Complex
{
db x,y;
Complex(db _x=0,db _y=0):x(_x),y(_y){}
Complex operator + (const Complex &b)const
{
return Complex(x+b.x,y+b.y);
}
Complex operator - (const Complex &b)const
{
return Complex(x-b.x,y-b.y);
}
Complex operator * (const Complex &b)const
{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
Complex operator / (const db &b)const
{
return Complex(x/b,y/b);
}
};
void change(Complex y[],int len)
{
for(int i=1,j=len/2;i<len-1;i++)
{
if(i<j)swap(y[i],y[j]);
int k=len/2;
while(j>=k)
{
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void fft(Complex y[],int len,int on)
{
change(y,len);
for(int h=2;h<=len;h<<=1)
{
Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h)
{
Complex w(1,0);
for(int k=j;k<j+h/2;k++)
{
Complex u=y[k];
Complex v=w*y[k+h/2];
y[k]=u+v;
y[k+h/2]=u-v;
w=w*wn;
}
}
}
if(on==-1)for(int i=0;i<len;i++)
y[i]=y[i]/len;
}
Complex p1[MAXN],p2[MAXN],p3[MAXN],q1[MAXN],q2[MAXN];
void multiply(int p[],int q[],int m)
{
int t=sqrt(Mod),len=1;
while(len<2*m)len<<=1;
for(int i=0;i<len;i++)
{
p1[i]=(i<m ? p[i]/t : 0);
p2[i]=(i<m ? p[i]%t : 0);
p3[i]=0;
q1[i]=(i<m ? q[i]/t : 0);
q2[i]=(i<m ? q[i]%t : 0);
}
fft(p1,len,1),fft(p2,len,1),fft(q1,len,1),fft(q2,len,1);
for(int i=0;i<len;i++)
{
p3[i]=p1[i]*q2[i]+p2[i]*q1[i];
p1[i]=p1[i]*q1[i];
p2[i]=p2[i]*q2[i];
}
fft(p1,len,-1),fft(p2,len,-1),fft(p3,len,-1);
for(int i=0;i<len;i++)
{
long long t1=p1[i].x+0.5,t2=p2[i].x+0.5,t3=p3[i].x+0.5;
p[i]=(t1*t*t+t*t3+t2)%Mod;
}
for(int i=m;i<len;i++)
{
p[i%m]=(p[i%m]+p[i])%Mod;
p[i]=0;
}
}
int main()
{
int n,m,p;
scanf("%d%d%d",&n,&m,&p);
int ans[MAXN],t[MAXN],g[MAXN];
memset(ans,0,sizeof(ans));
memset(t,0,sizeof(t));
memset(g,0,sizeof(g));
ans[0]=1;
for(int i='A'; i<='Z'; i++)
g[i%m]++;
while(n)
{
if(n&1)
{
int ck=0;
memset(t,0,sizeof(t));
for(int i=0; i<m; i++)
{
t[ck]=(t[ck]+ans[i])%Mod;
ck=(ck+p)%m;
}
multiply(t,g,m);
memcpy(ans,t,sizeof(t));
}
int ck=0;
memset(t,0,sizeof(t));
for(int i=0; i<m; i++)
{
t[ck]=(t[ck]+g[i])%Mod;
ck=(ck+p)%m;
}
multiply(g,t,m);
p=(1ll*p*p)%m;
n>>=1;
}
long long ans1=0;
for(int i=0;i<m;i++)
{
if(ans[i]<2)
ans[i]+=Mod;
ans1=(ans1+(1ll*ans[i]*(ans[i]-1)/2)%Mod)%Mod;
//cout<<i<<" "<<ans[i]<<endl;
}
printf("%lld\n",ans1);
}
最后一个美好的愿望,希望自己以后Carry多一点,sb少一点、。