题目
内容
给出一个字符串s,并规定某些字符合法 某些不合法.
求 ∣ a ∣ ∗ f ( a ) |a|*f(a) ∣a∣∗f(a) 的最大值, a a a为s的子串, f ( a ) f(a) f(a) 为以合法字符结尾的出现次数.
分析
建后缀数组,按height数组从大到小合并,并查集维护.
由于以非法字符结尾的子串不能计算在内,而后缀数组不能很方便确定子串的结尾是否合法,因此我们先将字符串s翻转,这样就将问题转化为子串开头是否合法.
我们初始化并查集时,将合法字符开头的后缀size设为1,非法字符开头的后缀设为0.之后将height数组排序,然后从大到小合并,每次更新答案.
另外只出现一次的子串需要特判.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
#define debug(x) cerr<<#x<<' '<<x<<'\n'
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);i>=(b);i--)
const int maxn=2e5+10;
const int mod=1e9+7;
const int inf=0x3f3f3f3f;
const int maxbit=20;
struct SuffixArray
{
int sa[maxn], rank[maxn], ws[maxn], wv[maxn], wa[maxn], wb[maxn], height[maxn], st[maxbit][maxn], N;
bool cmp(int *r, int a, int b, int l){return r[a]==r[b] and r[a+l]==r[b+l];}
void build(int *r, int n, int m)
{
N=n;
n++;
int i, j, k=0, p, *x=wa, *y=wb, *t;
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[x[i]=r[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[x[i]]]=i;
for(p=j=1;p<n;j<<=1,m=p)
{
for(p=0,i=n-j;i<n;i++)y[p++]=i;
for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
for(i=0;i<n;i++)wv[i]=x[y[i]];
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[wv[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,i=1,x[sa[0]]=0;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
for(i=0;i<n;i++)rank[sa[i]]=i;
for(i=0;i<n-1;height[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
}
void build_st() //st表
{
int i, k;
for(i=1;i<=N;i++)st[0][i]=height[i];
for(k=1;k<=maxbit;k++)
for(i=1;i+(1<<k)-1<=N;i++)
st[k][i]=min(st[k-1][i],st[k-1][i+(1<<k)-1]);
}
int lcp(int x, int y) //最长公共前缀
{
int l=rank[x], r=rank[y];
if(l>r)swap(l,r);
if(l==r)return N-sa[l];
int t=log2(r-l);
return min(st[t][l+1],st[t][r-(1<<t)+1]);
}
}SA;
int strcp[maxn],fa[maxn];
ll siz[maxn];
bool can[maxn];
pii pos[maxn];
int Find(int x) {
if(x==fa[x]) return x;
else return fa[x]=Find(fa[x]);
}
void Union(int a,int b) {
int faa=Find(a),fab=Find(b);
siz[faa]+=siz[fab];
fa[fab]=faa;
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);
int len;
cin>>len;
string a,b;
cin>>a;
cin>>b;
rep(i,0,len-1) {
strcp[len-i-1]=a[i];
if(b[i]=='0') can[len-i-1]=true;
else can[len-i-1]=false;
}
ll ans=0;
rep(i,0,len) {
fa[i]=i;
if(can[i]) {
siz[i]=1;
ans=max(ans,(ll)len-i);
}
}
strcp[len]=0;
SA.build(strcp,len,300);
int num=0;
rep(i,2,len) {
pos[num].fi=SA.height[i];
pos[num++].se=i;
}
sort(pos,pos+num);
per(i,num-1,0) {
int l=SA.sa[pos[i].se-1],r=SA.sa[pos[i].se];
l=Find(l),r=Find(r);
ans=max(ans,(siz[l]+siz[r])*pos[i].fi);
Union(l,r);
}
cout<<ans<<'\n';
return 0;
}