考虑一个只包含小写拉丁字母的字符串s。我们定义s的一个子串t的“出现值”为t在s中的出现次数乘以t的长度。请你求出s的所有回文子串中的最大出现值。
Manacher+SAM
注意:clone 结点会整体复制原结点的内容,但 size(出现次数)不能一起复制,必须清零。
坑:PAM
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
// Capacity of every global buffer in this program. Extend() allocates at most
// two automaton states per input character (the new state plus one clone), so
// this presumably targets inputs of up to ~300000 characters -- TODO confirm
// against the problem limits.
#define maxn 600010

// Bring in only the standard names this program actually uses. A blanket
// `using namespace std;` also imports std::size (C++17 and later), which makes
// every unqualified use of the global counter `size` ambiguous and stops the
// program from compiling.
using std::max;
using std::memset;
using std::min;
using std::printf;
using std::scanf;
using std::strlen;

// Input string, stored 1-based in s[1..n]; solve() writes sentinel characters
// to s[0] and s[n + 1].
char s[maxn];
// Length of the input string.
int n;
// Best "occurrences * length" value found so far.
long long ans;
// One suffix-automaton state.
struct Node{
    int len;      // length of the longest substring represented by this state
    int link;     // suffix link; -1 only for the root (see init())
    int nxt[26];  // transitions for 'a'..'z'; 0 means "no transition"
    int size;     // occurrence counter: 1 for states made by Extend(), 0 for clones, summed up by build()
};
// State pool; index 0 is the root.
Node st[maxn];
int root;  // index of the initial state
int size;  // index of the most recently allocated state
int last;  // state reached after the most recent Extend() call
void init(){
root = size = last = 0;
st[root].len = 0;
st[root].link = -1;
}
// anc[v][k]: the 2^k-th ancestor of state v along suffix links, or -1 when
// there is none (filled by build()).
int anc[maxn][20];
// pos[i]: the state created when the i-th character was appended (set by
// Extend()).
int pos[maxn];
// Appends one character to the suffix automaton.
//   ch   - lowercase letter to append
//   part - 1-based position of ch in the input; pos[part] records the state
//          created for it
void Extend(char ch, int part){
    const int c = ch - 'a';
    const int cur = ++size;
    st[cur].len = st[last].len + 1;
    st[cur].size = 1;  // a freshly created (non-clone) state starts with one occurrence
    pos[part] = cur;

    // Walk up the suffix links, adding the missing transition on c.
    int p = last;
    while(p != -1 && st[p].nxt[c] == 0){
        st[p].nxt[c] = cur;
        p = st[p].link;
    }

    if(p == -1){
        st[cur].link = root;
    }else{
        const int q = st[p].nxt[c];
        if(st[q].len == st[p].len + 1){
            st[cur].link = q;
        }else{
            // Split q: the clone takes over the shorter substrings.
            const int clone = ++size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            st[clone].size = 0;  // the copied counter must not be counted twice
            while(p != -1 && st[p].nxt[c] == q){
                st[p].nxt[c] = clone;
                p = st[p].link;
            }
            st[q].link = clone;
            st[cur].link = clone;
        }
    }
    last = cur;
}
// t[1..size]: the non-root states ordered by non-decreasing len (output of the
// counting sort in build()).
int t[maxn];
// w[len]: counting-sort buckets keyed by state length.
int w[maxn];
// Post-processes the finished automaton:
//   1. fills the binary-lifting table anc[][] over the suffix links;
//   2. counting-sorts the states by len into t[];
//   3. pushes the occurrence counters up the suffix links, longest states
//      first, so that st[v].size becomes the accumulated count for v.
void build(){
    memset(anc, -1, sizeof anc);
    for(int v = 1; v <= size; ++v)
        anc[v][0] = st[v].link;
    for(int k = 1; (1 << k) <= size; ++k){
        for(int v = 1; v <= size; ++v){
            const int mid = anc[v][k - 1];
            if(mid != -1)
                anc[v][k] = anc[mid][k - 1];
        }
    }

    // Counting sort keyed by len.
    for(int v = 1; v <= size; ++v)
        ++w[st[v].len];
    for(int len = 1; len <= size; ++len)
        w[len] += w[len - 1];
    for(int v = size; v >= 1; --v)
        t[w[st[v].len]--] = v;

    // Longest states first, so a state is complete before it is added to its link.
    for(int rank = size; rank >= 1; --rank){
        const int v = t[rank];
        st[st[v].link].size += st[v].size;
    }
}
void update(int l, int r){
int t = pos[r];
for(int i = 18; i >= 0; i --){
if(~anc[t][i]){
int to = anc[t][i];
if(st[to].len >= r - l + 1)
t = to;
}
}
ans = max(ans, 1ll * st[t].size * (r - l + 1));
}
// Manacher radius per centre; solve() uses it first for the even-length pass
// and then again for the odd-length pass.
int r[maxn];
// Builds the suffix automaton of s[1..n], then runs Manacher twice (even and
// odd lengths). Every time a palindrome is extended by one ring, the new
// palindrome is passed to update(), which scores it as occurrences * length.
// Prints the best score.
void solve(){
    // Two different non-letter sentinels stop the extension loops at both ends.
    s[0] = '*';
    s[n + 1] = '#';

    init();
    for(int i = 1; i <= n; i ++)
        Extend(s[i], i);
    build();

    // Pass 1: even lengths. r[i] = k means s[i-k+1 .. i+k] is a palindrome
    // centred between i and i+1. mx is the (inclusive) right end of the
    // rightmost palindrome seen so far and p is its left-centre index.
    int mx = 0, p = 0;
    for(int i = 1; i <= n; i ++){
        // Reflecting about the centre (p, p+1) maps x to 2p+1-x, so the pair
        // (i, i+1) maps to (2p-i, 2p-i+1): the mirrored centre is 2*p - i.
        // Seeding from any other index can start r[i] at a radius that is not
        // a palindrome and make update() score non-palindromic substrings
        // (for example on "dcaabaabaaabaacd").
        if(i < mx)r[i] = std::min(r[2 * p - i], mx - i);
        else r[i] = 0;
        while(s[i + r[i] + 1] == s[i - r[i]]){
            r[i] ++;
            update(i - r[i] + 1, i + r[i]);
        }
        if(r[i] + i > mx)mx = r[i] + i, p = i;
    }

    // Pass 2: odd lengths. r[i] = k means s[i-k+1 .. i+k-1] is a palindrome
    // centred at i; here mx is one past the right end of the rightmost
    // palindrome.
    mx = 0, p = 0;
    for(int i = 1; i <= n; i ++){
        // mx - i - 1 is one less than the tightest safe bound; the loop below
        // re-extends the missing ring (and reports s[i..i] when r[i] is 0).
        if(i < mx){r[i] = std::min(r[2 * p - i], mx - i - 1);}
        else {r[i] = 1;update(i, i);}
        while(s[i + r[i]] == s[i - r[i]]){
            r[i] ++;
            update(i - r[i] + 1, i + r[i] - 1);
        }
        if(r[i] + i > mx)mx = r[i] + i, p = i;
    }
    printf("%lld\n", ans);
}
// Reads one word from stdin into s[1..], records its length in n and runs the
// solver.
int main(){
    char* const text = s + 1;  // the string is stored 1-based
    scanf("%s", text);
    n = strlen(text);
    solve();
    return 0;
}
[upd:PAM]
#include <algorithm>
#include <cstdio>
#include <cstring>

// Capacity of the global buffers of the palindromic-tree solution.
#define maxn 300010

// Bring in only the standard names this program actually uses. A blanket
// `using namespace std;` also imports std::size (C++17 and later), which makes
// every unqualified use of the global counter `size` ambiguous and stops the
// program from compiling.
using std::max;
using std::memset;
using std::printf;
using std::scanf;
using std::strlen;

// Number of input characters.
int n;
// Letter codes appended so far, 1-based; s[0] is set to the sentinel -1 by init().
int s[maxn];
// Raw input text, stored 1-based in str[1..n].
char str[maxn];
// One palindromic-tree node.
struct Node{
    int fail;     // failure link (index into st[])
    int nxt[26];  // child per letter code 0..25; 0 means "no child"
    int len;      // length of the palindrome this node stands for
    // Resets the node to "no children, fail 0, length 0".
    void clear(){
        for(int c = 0; c < 26; ++c)
            nxt[c] = 0;
        fail = 0;
        len = 0;
    }
};
// Node pool; init() sets up index 0 (len 0) and index 1 (len -1) as the roots.
Node st[maxn];
int last;  // node selected by the most recent Extend() call
int size;  // index of the most recently allocated node
int len;   // number of characters appended so far
// Resets the tree to its two roots: node 0 with len 0 and node 1 with len -1.
// Both fail links point to node 1, and s[0] holds -1 as a sentinel.
void init(){
    len = 0;
    last = 0;
    st[0].clear();
    st[1].clear();
    st[0].len = 0;
    st[1].len = -1;
    st[0].fail = 1;
    st[1].fail = 1;
    s[0] = -1;
    size = 1;
}
// Follows fail links from x until it reaches a node whose palindrome is
// preceded, in the text read so far, by the same code as the newest one
// (s[len]); returns that node.
int get_fail(int x){
    const int newest = s[len];
    for(;;){
        if(s[len - st[x].len - 1] == newest)
            return x;
        x = st[x].fail;
    }
}
// sz[v]: incremented by Extend() each time node v becomes `last`; solve()
// later adds it along the fail links.
int sz[maxn];
// Appends letter code c (0..25) to the text and moves `last` to the matching
// child of the node returned by get_fail(last), creating that child if it
// does not exist yet.
void Extend(int c){
    len += 1;
    s[len] = c;
    const int cur = get_fail(last);
    int child = st[cur].nxt[c];
    if(child == 0){
        child = ++size;
        st[child].clear();
        st[child].len = st[cur].len + 2;
        // The fail link has to be resolved before the child is attached to
        // cur, so this lookup cannot return the child itself.
        st[child].fail = st[get_fail(st[cur].fail)].nxt[c];
        st[cur].nxt[c] = child;
    }
    last = child;
    ++sz[last];
}
// Walks the nodes from the newest down to index 1, adding each node's counter
// to its fail target and tracking the largest sz * len; prints that maximum.
void solve(){
    long long best = 0;
    int v = size;
    while(v >= 1){
        sz[st[v].fail] += sz[v];
        const long long value = static_cast<long long>(sz[v]) * st[v].len;
        if(value > best)
            best = value;
        --v;
    }
    printf("%lld\n", best);
}
// Reads one word from stdin, feeds it letter by letter into the palindromic
// tree and prints the answer.
int main(){
    init();
    char* const text = str + 1;  // the input is stored 1-based
    scanf("%s", text);
    n = strlen(text);
    for(int i = 1; i <= n; ++i){
        const int code = str[i] - 'a';
        Extend(code);
    }
    solve();
    return 0;
}