题意:子串的长度*子串出现的次数为这个子串的值,求最大值。还有个约束条件串t,若t[i]=='1',则以该位置结尾的子串不能使用
题解:
把串倒转过来,这样以该位置结尾不能用的条件转化为以该位置开头的子串不能用
若先不考虑约束条件,该题的答案就是某一子串的长度*该子串出现的次数
而出现次数非常好求,根据做完后缀数组后的height[]数组来。
l[i]为i左边第一个小于height[i]的位置,r[i]为i右边第一个小于height[i]的位置,出现次数就是r[i]-l[i];
l,r数组用单调栈O(n)的就能求出来。现在考虑约束条件,根据串t,用sum[]前缀和求出sa数组的‘1’的总个数,为‘1’的串不能用。最后出现次数减去就是l[i]到r[i]的1的个数(不能用的串的个数)就行
#include<bits/stdc++.h>
#define rint register int
#define inv inline void
#define ini inline int
#define maxn 2000050
typedef long long ll;
using namespace std;
char s[maxn],t[maxn];
int y[maxn],x[maxn],c[maxn],sa[maxn],rk[maxn],height[maxn],wt[30];
int n,m;
inv putout(int x) {
if(!x) {
putchar(48);
return;
}
rint l=0;
while(x) wt[++l]=x%10,x/=10;
while(l) putchar(wt[l--]+48);
}
inv get_SA() {
for (rint i=1; i<=n; ++i) ++c[x[i]=s[i]];
for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
for (rint i=n; i>=1; --i) sa[c[x[i]]--]=i;
for (rint k=1; k<=n; k<<=1) {
rint num=0;
for (rint i=n-k+1; i<=n; ++i) y[++num]=i;
for (rint i=1; i<=n; ++i) if (sa[i]>k) y[++num]=sa[i]-k;
for (rint i=1; i<=m; ++i) c[i]=0;
for (rint i=1; i<=n; ++i) ++c[x[i]];
for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
for (rint i=n; i>=1; --i) sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
x[sa[1]]=1;
num=1;
for (rint i=2; i<=n; ++i)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k]) ? num : ++num;
if (num==n) break;
m=num;
}
}
inv get_height() {
rint k=0;
for (rint i=1; i<=n; ++i) rk[sa[i]]=i;
for (rint i=1; i<=n; ++i) {
if (rk[i]==1) continue;//第一名height为0
if (k) --k;//h[i]>=h[i-1]-1;
rint j=sa[rk[i]-1];
while (j+k<=n && i+k<=n && s[i+k]==s[j+k]) ++k;
height[rk[i]]=k;//h[i]=height[rk[i]];
}
}
int sum[maxn],st[maxn],l[maxn],r[maxn];
int main() {
scanf("%d",&n);
scanf("%s%s",s+1,t+1);
reverse(s+1,s+n+1);
reverse(t+1,t+n+1);
m=200;
get_SA();
get_height();
for(int i=1;i<=n;i++){
sum[i]+=sum[i-1];
if(t[sa[i]]=='1') sum[i]++;
//printf("@%d\n",sum[i]);
}
ll ans=0;
for(int i=1;i<=n;i++){
if(t[i]=='0'){
ans=n-i+1;
break;
}
}
int tail=0;
for(int i=1;i<=n;i++){
while(tail>0&&height[st[tail]]>height[i]){
r[st[tail]]=i;
tail--;
}
st[++tail]=i;
}
while(tail>0){
r[st[tail]]=n+1;
tail--;
}
tail=0;
for(int i=n;i>=1;i--){
while(tail>0&&height[st[tail]]>height[i]){
l[st[tail]]=i;
tail--;
}
st[++tail]=i;
}
while(tail>0){
l[st[tail]]=0;
tail--;
}
for(int i=1;i<=n;i++){
//printf("!%d %d %d %d\n",l[i],r[i],height[i],(sum[r[i]-1]-sum[l[i]]));
ll tmp=r[i]-l[i];
tmp-=(sum[r[i]-1]-sum[l[i]-1]);
ans=max(ans,height[i]*tmp);
}
printf("%lld\n",ans);
}