题目链接https://codeforces.com/problemset/problem/1120/C
题意
给出一串字符串,将它拆成一些子串拼接。
任何一个子串的价值都可以是
a
a
a,如果某个子串
t
i
t_i
ti是字符串
t
1
t
2
.
.
t
i
−
1
t_1t_2..t_{i-1}
t1t2..ti−1的子串,它的价值可以是
b
b
b。问价值最小是多少
题解
从后往前 d p dp dp,如果 s [ i , j ] s[i,j] s[i,j]是 s [ 1 , i − 1 ] s[1,i-1] s[1,i−1]的子串,那么 d p [ i ] = d p [ j + 1 ] + b dp[i]=dp[j+1]+b dp[i]=dp[j+1]+b。如果 s [ i , j ] s[i,j] s[i,j]的长度用 l e n len len表示,很显然 l e n len len越大越好,所以我们可以先二分一个 l e n len len,然后判断是不是子串。
判断是不是子串要先跑个后缀数组,然后也是个二分,在 r a n k [ i ] rank[i] rank[i]左右两边二分出 l e le le和 r i ri ri,使得 r a n k rank rank在 l e le le到 r i ri ri的后缀与我们当前的后缀 s u f [ i ] suf[i] suf[i]的最长公共前缀都大于等于 l e n len len。最后再查询一下 l e 到 r i le到ri le到ri的最左左端点,判断是否与 s u f [ i ] suf[i] suf[i]重叠。
这些乱七八糟的最值都可以用ST表预处理出来 O ( 1 ) O(1) O(1)查询,最后总的复杂度是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。
注意二分最长公共前缀的边界问题
妈哎,好像做复杂了
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+7;
int wa[N],wb[N],wv[N],we[N];
int cmp(int *r,int a,int b,int l){return r[a]==r[b]&&r[a+l]==r[b+l];}
void da(const char r[],int sa[],int n,int m){
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<m;i++) we[i]=0;
for(i=0;i<n;i++) we[x[i]=r[i]]++;
for(i=1;i<m;i++) we[i]+=we[i-1];
for(i=n-1;i>=0;i--) sa[--we[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p){
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) we[i]=0;
for(i=0;i<n;i++) we[wv[i]]++;
for(i=1;i<m;i++) we[i]+=we[i-1];
for(i=n-1; i>=0; i--) sa[--we[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1; i<n; i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return;
}
int sa[N],rk[N],ht[N];
void calh(const char *r,int *sa,int n){
int i,j,k=0;
for(i=1;i<=n;i++) rk[sa[i]]=i;
for(i=0;i<n;ht[rk[i++]]=k)
for(k?k--:0,j=sa[rk[i]-1];r[i+k]==r[j+k];k++);
for(int i=n;i>=1;--i) ++sa[i],rk[i]=rk[i-1];
}
int n,a,b;
char s[N];
int dp[N],lg[N];
int mi[N][30],le[N][30];
int qm(int l,int r){
int k=lg[r-l+1];
return min(mi[l][k],mi[r-(1<<k)+1][k]);
}
int ql(int l,int r){
int k=lg[r-l+1];
return min(le[l][k],le[r-(1<<k)+1][k]);
}
bool ck(int len,int x){
int lo=2,hi=x,le=x;
while(lo<=hi){
int mid=lo+hi>>1;
if(qm(mid,x)>=len) hi=mid-1,le=mid-1;
else lo=mid+1;
}
lo=x+1,hi=n;int ri=x;
while(lo<=hi){
int mid=lo+hi>>1;
if(qm(x+1,mid)>=len) lo=mid+1,ri=mid;
else hi=mid-1;
}
return ql(le,ri)+len-1<sa[x];
}
int main()
{
for(int i=1;i<N;i++) lg[i]=log2(i);
scanf("%d%d%d",&n,&a,&b);
scanf("%s",s);
da(s,sa,n+1,130);
calh(s,sa,n);
for(int i=1;i<=n;i++){
mi[i][0]=ht[i];
le[i][0]=sa[i];
}
for(int j=1;j<22;j++){
for(int i=1;i+(1<<j)-1<=n;i++){
mi[i][j]=min(mi[i][j-1],mi[i+(1<<(j-1))][j-1]);
le[i][j]=min(le[i][j-1],le[i+(1<<(j-1))][j-1]);
}
}
for(int i=n;i>=1;i--){
dp[i]=dp[i+1]+a;
int lo=1,hi=n-i+1,res=-1;
while(lo<=hi){
int mid=lo+hi>>1;
if(ck(mid,rk[i])) res=mid,lo=mid+1;
else hi=mid-1;
}
if(res!=-1) dp[i]=min(dp[i],dp[i+res]+b);
}
printf("%d\n",dp[1]);
}