题面
题目链接
https://ac.nowcoder.com/acm/problem/14894
题目大意
有两个长度均为n的字符串A和B。
可以从A中选一个可以为空的子串A[l1..r1],B中选一个可以为空的子串B[l2..r2]
需要满足r1 = l2,然后把它们拼起来(A[l1..r1]+B[l2..r2])
求用这样的方法能得到的最长回文串 S 的长度。
解题思路
首先对 A , B 串都跑一边 Manacher,分别得到 PA , PB 和处理过的 A , B
然后再枚举处理过的 A / B,以每个字符作为 S 的中心点,半径为 max(PA , PB) 进行拓展即可
因为以 i 为中心 PA / PB 为半径的回文串中,它的最小回文长度就达到了 PA / PB - 1
所以只要从 PA / PB 两端拓展看还能找到多少可以相匹配的字符就可以了
而假设 S 的中心点为 i , A 提供的是 i - 1 , i - 2 ... ,B 提供的是 i + 1 , i + 2 ...,他们之间相差了 2
所以当枚举到 A 的第 i 个字符时,需要操作的是 PA[ i ] 和 PB[ i - 2 ]
而拓展的方法有两种,一种是逐一匹配,俗称 brute force
另一种是 二分拓展的长度 + hash check 来匹配
第一种做法的复杂度我不太会算,感觉会超时但才跑了500ms?
而第二种做法显然要快上不少,大概是 50ms
这里提供两种做法
AC_Coder_(暴力)
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
string a , b;
int pa[N] , pb[N] , res = 1;
string Manacher(string a , int *p)
{
string t = "$#";
for(auto i : a) t += i , t += '#';
int mx = 0 , id = 0 ;
int len = t.size() , ans = 0;
for(int i = 1 ; i < len ; i ++)
{
p[i] = mx > i ? min(p[2 * id - i] , mx - i) : 1;
while(t[i + p[i]] == t[i - p[i]]) p[i] ++ ;
if(mx < i + p[i]) mx = i + p[i] , id = i;
ans = max(ans , p[i] - 1);
}
res = max(res , ans);
return t;
}
signed main()
{
int n ;
cin >> n >> a >> b;
a = Manacher(a , pa) , b = Manacher(b , pb);
n = n * 2 + 2;
int ans = 1;
for(int i = 2 ; i <= n ; i ++)
{
int len = max(pa[i] , pb[i - 2]);
while(a[i - len] == b[i - 2 + len]) len ++;
ans = max(ans , len - 1);
}
cout << ans << '\n';
return 0;
}
AC_Coder_(hash + 二分)
#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
using namespace std;
const int N = 3e5 + 10;
const int MOD = 999998639;
const int P = 13331;
string a , b;
int pa[N] , pb[N] , res = 1;
ull pre[N] , suf[N] , power[N];
string Manacher(string a , int *p)
{
string t = "$#";
for(auto i : a) t += i , t += '#';
int mx = 0 , id = 0 ;
int len = t.size() , ans = 0;
for(int i = 1 ; i < len ; i ++)
{
p[i] = mx > i ? min(p[2 * id - i] , mx - i) : 1;
while(t[i + p[i]] == t[i - p[i]]) p[i] ++ ;
if(mx < i + p[i]) mx = i + p[i] , id = i;
ans = max(ans , p[i] - 1);
}
res = max(res , ans);
return t;
}
ull get_hash1(int l , int r)
{
if(l > r) return -999;
return (pre[r] - pre[l - 1] * power[r - l + 1] % MOD + MOD) % MOD;
}
ull get_hash2(int l , int r)
{
if(l > r) return -888;
return (suf[l] - suf[r + 1] * power[r - l + 1] % MOD + MOD) % MOD;
}
void init(int n)
{
power[0] = 1;
for(int i = 1 ; i < N - 5 ; i ++) power[i] = power[i - 1] * P % MOD;
pre[0] = a[0] ;
for(int i = 1 ; i < n ; i ++) pre[i] = (pre[i - 1] * P + a[i]) % MOD ;
suf[n - 1] = b[n - 1];
for(int i = n - 2 ; i >= 0 ; i --) suf[i] = (suf[i + 1] * P + b[i]) % MOD;
}
signed main()
{
int n ;
cin >> n >> a >> b;
a = Manacher(a , pa) , b = Manacher(b , pb);
n = n * 2 + 2;
init(n);
int ans = 1;
for(int i = 2 ; i <= n ; i ++)
{
int add = 0;
int len = max(pa[i] , pb[i - 2]);
int sa = i - len + 1 , sb = i - 2 + len - 1;
int l = 0 , r = min(sa - 1 , n - sb);
while(l <= r)
{
int mid = l + r >> 1;
int l1 = max(0LL , sa - mid) , r1 = min(n , sa - 1);
int l2 = max(sb + 1 , 0LL) , r2 = min(n , sb + mid);
if(get_hash1(l1 , r1) == get_hash2(l2 , r2))
l = mid + 1 , add = mid;
else r = mid - 1;
}
ans = max(ans , add + len - 1);
}
cout << ans << '\n';
return 0;
}