Description
给一个字符串 s s s,从中选出一个前缀与一个后缀(不重叠)前后拼在一起,且拼在一起的串是回文,求最大长度的此类串
Solution
官方的做法是用一个小结论:
设前缀长度为
l
l
l,后缀长度为
r
r
r,
则
∀
i
∈
[
1
,
m
i
n
(
l
,
r
)
]
,
s
[
i
]
=
s
[
l
e
n
−
i
+
1
]
\forall i \in [1,min(l, r)],s[i] = s[len - i + 1]
∀i∈[1,min(l,r)],s[i]=s[len−i+1]
设
k
=
m
i
n
(
l
,
r
)
m
a
x
k = min(l, r)_{max}
k=min(l,r)max, 答案必有一种方案满足
m
i
n
(
l
,
r
)
=
k
min(l, r) = k
min(l,r)=k
当
m
i
n
(
l
,
r
)
<
k
min(l, r) < k
min(l,r)<k,不妨设
l
<
r
l < r
l<r,则
s
[
l
+
1
]
=
s
[
l
e
n
−
l
]
=
s
[
l
e
n
−
r
+
1
]
s[l+1]=s[len - l]=s[len - r + 1]
s[l+1]=s[len−l]=s[len−r+1]
我们可以让
r
−
−
,
l
+
+
r--,l++
r−−,l++,这样仍是长度相同的回文串
于是我们先找出
k
k
k,然后把前
k
k
k个和后
k
k
k个字符删除
再找出最长前缀回文和最长后缀回文比较选择
找出最长前缀回文最好想的是
h
a
s
h
hash
hash(用
m
a
p
map
map好像会
M
L
E
MLE
MLE)
标程是将原串与反串连在一起,中间加一个非小写字母的字符,然后跑一遍
k
m
p
kmp
kmp
当然还有更弱智的做法:
这题是多组数据,考场我用
m
a
p
map
map没清零导致爆炸,就开始质疑上面那个正确的结论
…
…
\dots\dots
……
然后就发现这是一道回文自动机裸题,求出以每个位置为结尾或开头的最长回文串,然后对于
i
<
=
k
i <= k
i<=k,它的贡献就是
2
∗
i
+
2 * i +
2∗i+以
i
+
1
i+1
i+1开头的最长回文串/以
l
e
n
−
i
len-i
len−i结尾的最长回文串,当然一旦选取最长回文串使得前后缀重叠就要跳
f
a
i
l
fail
fail链,但这样复杂度无法保证
O
(
n
)
O(n)
O(n)
于是我们不需要每次都求最长回文串,每次只要求到不导致重叠的最长回文串就行
我的两份代码:
#include <bits/stdc++.h>
using namespace std;
int T, len, a, b, nxt[2000010];
char s[1000010];
int solve(string ss) {
string t = ss;
reverse(ss.begin(), ss.end());
t = ' ' + t + '#' + ss;
//cerr << t << "-------------------\n";
int j = 0;
for (int i = 2; i < t.size(); i++) {
while (j > 0 && t[i] != t[j + 1]) j = nxt[j];
if (t[i] == t[j + 1]) j++;
nxt[i] = j;
}
return j;
}
int main() {
cin >> T;
while (T--) {
scanf ("%s", s + 1);
len = strlen(s + 1);
a = 1, b = len;
for (a = 1, b = len; a < b && s[a] == s[b]; a++, b--);
if (a >= b) {
for (int i = 1; i <= len; i++) putchar(s[i]);
putchar('\n');
continue;
}
string str = "";
for (int i = a; i <= b; i++) str += s[i];
int ans1 = solve(str);
str = "";
for (int i = b; i >= a; i--) str += s[i];
int ans2 = solve(str);
//cout << a << " " << b << " " << ans1 << " " << ans2 << endl;
if (ans1 < ans2) {
for (int i = 1; i < a; i++) putchar(s[i]);
for (int i = b - ans2 + 1; i <= len; i++) putchar(s[i]);
}
else {
for (int i = 1; i < a + ans1; i++) putchar(s[i]);
for (int i = b + 1; i <= len; i++) putchar(s[i]);
}
putchar('\n');
}
}
#include <bits/stdc++.h>
using namespace std;
int T, siz, tot, lst, fail[1000010], len[1000010], trans[1000010][26];
int ans1, ans2, pos1, pos2;
char s[1000010], t[1000010];
void Insert1(int c, int N) {
int cur = lst;
while (t[N - len[cur] - 1] != t[N]) cur = fail[cur];
int now = trans[cur][c];
if (!now) {
now = ++tot;
memset(trans[now], 0, sizeof(trans[now]));
int v = fail[cur];
while (t[N - len[v] - 1] != t[N]) v = fail[v];
len[now] = len[cur] + 2;
fail[now] = trans[v][c];
trans[cur][c] = now;
//num[now] = num[fail[now]] + 1;
}
lst = now;
}
void Insert2(int c, int N) {
int cur = lst;
while (s[N - len[cur] - 1] != s[N]) cur = fail[cur];
int now = trans[cur][c];
if (!now) {
now = ++tot;
memset(trans[now], 0, sizeof(trans[now]));
int v = fail[cur];
while (s[N - len[v] - 1] != s[N]) v = fail[v];
len[now] = len[cur] + 2;
fail[now] = trans[v][c];
trans[cur][c] = now;
//num[now] = num[fail[now]] + 1;
}
lst = now;
}
int main() {
cin >> T;
while (T--) {
scanf("%s", s + 1);
siz = strlen(s + 1);
int k;
for (k = 1; k <= siz - k && s[k] == s[siz - k + 1]; k++);
if (k > siz - k) {
for (int i = 1; i <= siz; i++) putchar(s[i]);
putchar('\n');
continue;
}
ans1 = ans2 = pos1 = pos2 = 0;
for (int i = 0; i <= siz; i++) fail[i] = 0;
fail[1] = 1, fail[0] = 1, len[1] = -1, len[0] = lst = 0, tot = 1;
memset(trans[0], 0, sizeof(trans[0]));
memset(trans[1], 0, sizeof(trans[1]));
for (int i = 1; i <= siz; i++) {
t[i] = s[siz - i + 1];
Insert1(t[i] - 'a', i);
if (i + k > siz) {
while (lst > 1 && siz - i + len[lst] > i) lst = fail[lst];
if (ans1 < 2 * (siz - i) + len[lst]) {
ans1 = 2 * (siz - i) + len[lst];
pos1 = siz - i;
}
}
}
for (int i = 0; i <= siz; i++) fail[i] = 0;
fail[1] = 1, fail[0] = 1, len[1] = -1, len[0] = lst = 0, tot = 1;
memset(trans[0], 0, sizeof(trans[0]));
memset(trans[1], 0, sizeof(trans[1]));
for (int i = 1; i <= siz; i++) {
Insert2(s[i] - 'a', i);
if (i + k > siz) {
while (lst > 1 && siz - i + len[lst] > i) lst = fail[lst];
if (ans2 < 2 * (siz - i) + len[lst]) {
ans2 = 2 * (siz - i) + len[lst];
pos2 = siz - i;
}
}
}
if (ans1 < ans2) {
for (int i = 1; i <= pos2; i++) putchar(s[i]);
for (int i = siz - (ans2 - pos2) + 1; i <= siz; i++) putchar(s[i]);
}
else {
for (int i = 1; i <= ans1 - pos1; i++) putchar(s[i]);
for (int i = siz - pos1 + 1; i <= siz; i++) putchar(s[i]);
}
putchar('\n');
}
return 0;
}