朴素算法:用 $O(n^2)$ 的递推求出任意两个前缀的最长公共后缀,再在此基础上做 DP。
#include <bits/stdc++.h>
using namespace std;
// Capacity of every fixed-size buffer below (5005 slots; data is stored from index 1).
constexpr int N = 5005;

// The three integers read from the first input line, in this order.
int n;
int a;
int b;

// f[i]: DP value for the first i characters of str (filled by main).
int f[N];

// g[i][j], i < j: length of the longest common suffix of str[1..i] and
// str[1..j] (filled by main).  Row/column 0 stays zero and serves as the
// recurrence base.
int g[N][N];

// Input string; main reads it into str+1 so characters are 1-indexed.
char str[N];
int main(){
scanf("%d%d%d",&n,&a,&b);
scanf("%s",str+1);
for (register int i=1; i<n; ++i)
for (register int j=i+1; j<=n; ++j)
if (str[i]==str[j]) g[i][j]=g[i-1][j-1]+1;
else g[i][j]=0;
memset(f,60,sizeof(f));
f[1]=a;
for (register int i=2; i<=n; ++i)
{
f[i]=f[i-1]+a;
for (register int j=1; j<i; ++j)
if (j<i-g[j][i]+1) f[i]=min(f[i],f[i-g[j][i]]+b);
}
printf("%d\n",f[n]);
return 0;
}
现在我们可以通过后缀数组(SA)来完成这项操作,同时用 ST 表维护 height 数组的区间最小值,从而得到答案。
#include <bits/stdc++.h>
using namespace std;
// Capacity of every fixed-size buffer below (5005 slots; data is stored from index 1).
constexpr int N = 5005;

// n, a, b: the three integers read by main.  m: current number of distinct
// sort keys (bucket count) used by qsort().
int n;
int m;
int a;
int b;

// Suffix-array work arrays, all 1-indexed:
int sum[N];     // bucket counters for the counting sort
int rk[N];      // rk[p]: current rank of the suffix starting at p
int rk2[N];     // snapshot of rk[] from the previous doubling round
int tp[N];      // positions ordered by their secondary key
int sa[N];      // sa[r]: start position of the suffix with rank r
int height[N];  // height[r]: LCP of the suffixes ranked r-1 and r

// Input string; main reads it into s+1 so characters are 1-indexed.
char s[N];

// LOG[i]: floor(log2(i)) as filled by ST().
int LOG[N];
// minn[i][j]: minimum of height[i .. i+2^j-1] (sparse table filled by ST()).
int minn[N][14];
// f[i]: DP cost of the first i characters (filled by dp()).
int f[N];
/**
 * One counting-sort pass used by the suffix-array construction.
 *
 * Reads:  n, m, rk[1..n] (primary key of each position, expected in [0, m])
 *         and tp[1..n] (the positions, listed in the order ties should keep).
 * Writes: sum[0..m] (scratch bucket counters) and sa[1..n].
 *
 * Positions end up in sa[] by ascending rk.  tp[] is walked from the back
 * while each bucket is filled from its end, so positions with equal rk keep
 * the relative order they have in tp[].
 */
inline void qsort()
{
    // The `register` specifier of the old loops is gone: it was removed in
    // C++17 and is rejected or warned about by current compilers.
    for (int key = 0; key <= m; ++key) sum[key] = 0;
    for (int pos = 1; pos <= n; ++pos) ++sum[rk[pos]];
    // Prefix sums: sum[k] becomes the last sa[] slot of bucket k.
    for (int key = 1; key <= m; ++key) sum[key] += sum[key - 1];
    for (int idx = n; idx >= 1; --idx)
    {
        const int pos = tp[idx];
        sa[sum[rk[pos]]--] = pos;
    }
}
/**
 * Builds the suffix array of s[1..n] by prefix doubling with radix sort.
 *
 * On return sa[r] is the start of the r-th smallest suffix and rk[p] is the
 * rank of the suffix starting at p (ranks 1..n, all distinct).
 * Uses tp[], rk2[], sum[] and m as scratch; calls qsort() for each pass.
 */
inline void SA()
{
    // Initial keys are the raw character codes.  Going through unsigned char
    // keeps every key inside [1, 255] for any byte scanf("%s") can store,
    // whereas the former `s[i]-'0'+1` went non-positive (out-of-bounds bucket)
    // for characters below '0'.  The mapping is still order-preserving, so the
    // resulting suffix array is the same.
    m = 255;
    for (int i = 1; i <= n; ++i)
    {
        rk[i] = static_cast<unsigned char>(s[i]);
        tp[i] = i;
    }
    qsort();

    int p = 0;  // number of distinct ranks after the latest round
    for (int len = 1; p < n; m = p, len <<= 1)
    {
        // Order positions by their second key (rank of the suffix len further
        // on).  Suffixes with no second half come first ...
        p = 0;
        for (int i = n - len + 1; i <= n; ++i) tp[++p] = i;
        // ... followed by the rest in the order of the previous round.
        for (int i = 1; i <= n; ++i)
            if (sa[i] > len) tp[++p] = sa[i] - len;
        qsort();

        // Re-rank: neighbours in sa[] share a rank only when both halves match.
        std::memcpy(rk2, rk, sizeof(rk2));
        p = 1;
        rk[sa[1]] = 1;
        for (int i = 2; i <= n; ++i)
        {
            const int cur = sa[i];
            const int prev = sa[i - 1];
            // The second comparison is evaluated only when the first halves
            // are equal; rk2 slots past n are never written here and so read
            // as 0, i.e. "smaller than every real rank".
            const bool same = rk2[cur] == rk2[prev]
                           && rk2[cur + len] == rk2[prev + len];
            if (!same) ++p;
            rk[cur] = p;
        }
    }
}
inline void LCP()
{
for (register int i=1; i<=n; ++i) rk[sa[i]]=i;
int k=0;
for (register int i=1; i<=n; ++i)
{
if (rk[i]==1) continue;
if (k) k--;
int j=sa[rk[i]-1];
while (i+k<=n && j+k<=n && s[i+k]==s[j+k]) k++;
height[rk[i]]=k;
}
}
/**
 * Builds the range-minimum sparse table over height[1..n].
 *
 * LOG[i]     = floor(log2(i)) for 1 <= i <= n (LOG[0] = -1 seeds the recurrence).
 * minn[i][j] = min(height[i .. i + 2^j - 1]) for every window that fits in [1, n].
 */
inline void ST()
{
    LOG[0] = -1;
    for (int i = 1; i <= n; ++i) LOG[i] = LOG[i >> 1] + 1;
    for (int i = 1; i <= n; ++i) minn[i][0] = height[i];

    // Number of levels the table was declared with (replaces the literal 13).
    const int levels = static_cast<int>(sizeof(minn[0]) / sizeof(minn[0][0]));
    // Levels whose window is longer than n have no valid start, so stop early.
    for (int j = 1; j < levels && (1 << j) <= n; ++j)
    {
        const int half = 1 << (j - 1);
        for (int i = 1; i + (1 << j) - 1 <= n; ++i)
            minn[i][j] = std::min(minn[i][j - 1], minn[i + half][j - 1]);
    }
}
/**
 * Forward DP over prefix lengths; afterwards f[n] is the value main() prints.
 *
 * From state i (the first i characters are paid for) two moves are relaxed:
 *   - pay a and advance to i+1;
 *   - pay b and advance by maxn, the maximum over j in [1, i] of
 *     min(LCP(suffix j, suffix i+1), i-j+1).  The LCP is the minimum of
 *     height[] over the rank interval (x, y], answered from minn[][]; the cap
 *     i-j+1 keeps the copy that starts at j inside the first i characters.
 * Only the longest such jump is relaxed for the b move.
 *
 * Requires rk[], LOG[] and minn[][] to be filled (SA(), LCP(), ST()).
 */
inline void dp()
{
    std::memset(f, 60, sizeof(f));   // every byte 0x3c -> large "infinity" sentinel
    f[0] = 0;
    for (int i = 0; i < n; ++i)
    {
        f[i + 1] = std::min(f[i + 1], f[i] + a);

        int maxn = 0;
        for (int j = 1; j <= i; ++j)
        {
            int x = rk[j];
            int y = rk[i + 1];
            if (x > y) std::swap(x, y);
            // j != i+1, so x < y and the query range height[x+1..y] is non-empty.
            // (Named `lvl` rather than `s` so it no longer hides the global string.)
            const int lvl = LOG[y - x];
            const int lcp = std::min(minn[x + 1][lvl], minn[y - (1 << lvl) + 1][lvl]);
            maxn = std::max(maxn, std::min(lcp, i - j + 1));
        }
        f[i + maxn] = std::min(f[i + maxn], f[i] + b);
    }
}
int main(){
scanf("%d%d%d",&n,&a,&b);
scanf("%s",s+1);
SA();
LCP();
ST();
dp();
printf("%d\n",f[n]);
return 0;
}