题意
给定一个字符串,每一个后缀
i
…
n
i \dots n
i…n 有一个权值
w
i
w_i
wi。
求
m
a
x
{
L
C
P
(
i
,
j
)
+
w
i
⊕
w
j
}
,
(
i
≠
j
)
max\{LCP(i,j)+ w_i\oplus w_j\},(i \neq j)
max{LCP(i,j)+wi⊕wj},(i=j)。
题解
考虑先把 s a m sam sam 建出来(这里的 s a m sam sam 要反着建,因为要表示前缀的后缀)。
然后考虑 s a m sam sam 上的一个节点 i i i 的 e n d p o s endpos endpos 集合,他们所对应的后缀节点是确定的,所以他们之间的 L C P LCP LCP 就是 l o n g e s t ( i ) longest(i) longest(i)。所以我们只需要考虑他们之间两两异或的最大值即可。
对后缀自动机上每个节点维护一颗 t r i e trie trie 树,考虑在合并时统计答案(因为儿子的 e n d p o s endpos endpos 集合大小越少,匹配长度越长),所以我们只需要考虑不同儿子之间的异或最大值即可。用启发式合并 + s e t +set +set 维护。
代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e5+10;
int w[N],n,m,ans;
char c[N];
int rt[N<<5],lc[N<<5],rc[N<<5],sum[N<<5],sz;
set<int> e[N];
void insert(int &p,int x,int dep){
if(!p)p=++sz;
sum[p]++;
if(dep<0)return;
if(x&(1<<dep))insert(rc[p],x,dep-1);
else insert(lc[p],x,dep-1);
}
int query(int p,int x,int dep){
if(dep<0)return 0;
int k=(x&(1<<dep));
if(k){
if(sum[lc[p]])return (1<<dep)+query(lc[p],x,dep-1);
return query(rc[p],x,dep-1);
}
else{
if(sum[rc[p]])return (1<<dep)+query(rc[p],x,dep-1);
return query(lc[p],x,dep-1);
}
}
struct SAM{
int now,fa[N],len[N],t[N][26],las;
int tot,ver[N],fst[N],nxt[N];
inline void add(int x,int y){ver[++tot]=y;nxt[tot]=fst[x];fst[x]=tot;}
inline void jia(int c,int x,int id){
int k=las;las=++now;len[now]=len[k]+1;
e[las].insert(id);insert(rt[las],x,20);
while(k&&!t[k][c])t[k][c]=now,k=fa[k];
if(!k)fa[now]=1;
else{
int q=t[k][c];
if(len[q]==len[k]+1)fa[now]=q;
else{
fa[++now]=fa[q];fa[q]=fa[now-1]=now;len[now]=len[k]+1;
for(int j=0;j<26;++j)t[now][j]=t[q][j];
while(k&&t[k][c]==q)t[k][c]=now,k=fa[k];
}
}
}
inline void build(){for(int i=2;i<=now;++i)add(fa[i],i);}
void dfs(int x){
for(int i=fst[x];i;i=nxt[i]){
int y=ver[i];
dfs(y);
if(e[x].size()<e[y].size())swap(e[x],e[y]),swap(rt[x],rt[y]);
for(auto it=e[y].begin();it!=e[y].end();it++){
int k=*it;
ans=max(ans,len[x]+query(rt[x],w[k],20));
}
for(auto it=e[y].begin();it!=e[y].end();it++){
int k=*it;
e[x].insert(k);insert(rt[x],w[k],20);
}
}
}
}T;
int main(){
scanf("%d",&n);
scanf("%s",c+1);
T.las=T.now=1;
for(int i=1;i<=n;++i)scanf("%d",&w[i]);
for(int i=n;i;--i)T.jia(c[i]-'a',w[i],i);
T.build();T.dfs(1);
printf("%d\n",ans);
return 0;
}