牛客-合并回文子串
原题链接
题目描述
输入两个字符串A和B,合并成一个串C,属于A和B的字符在C中顺序保持不变。如"abc"和"xyz"可以被组合成"axbycz"或"abxcyz"等。
我们定义字符串的价值为其最长回文子串的长度(回文串表示从正反两边看完全一致的字符串,如"aba"和"xyyx")。
需要求出所有可能的C中价值最大的字符串,输出这个最大价值即可
输入描述
第一行一个整数T(T ≤ 50)。
接下来2T行,每两行两个字符串分别代表A,B(|A|,|B| ≤ 50),A,B的字符集为全体小写字母。
输出描述
对于每组数据输出一行一个整数表示价值最大的C的价值。
测试样例
示例1
输入
2
aa
bb
a
aaaabcaa
输出
4
5
思路
本题可以用区间dp求解。
-
区间dp
区间dp是线性dp的扩展,它在分阶段的划分问题时,与阶段中元素出现的顺序和由前一阶段的哪些元素合并而来有很大关系。 -
分析
本题有这样一个关键信息:属于A和B的字符在C中顺序保持不变
这让我们可以对答案有一个分析:
假设,我们已经找到了所有可能的C中价值最大的子串,那么这个子串一定是由A中某个子区间和B中某个子区间排列而成
这很好理解,因为题目给了属于A和B中的字符在C中顺序保持不变,所以最终C中这个最长的回文子串,一定是从A中连续地取一些字符,从B中连续地取一些字符,以某种排列组合而成。
这让我们想到,可以枚举A的子区间长度和起点,以及B的子区间长度和起点进行求解//lena表示A子区间长度, lenb表示B子区间长度 //la是A子区间左端点,ra是右端点(可根据左端点和子区间长度得出)lb,rb同理 for (int lena = 0; lena <= na; ++lena) { for (int lenb = 0; lenb <= nb; ++lenb) { for (int la = 1, ra = la + lena - 1; ra <= na; ++la, ++ra) { for (int lb = 1, rb = lb + lenb - 1; rb <= nb; ++lb, ++rb) { } } } }
通过以上方法枚举可以得到A的一个子区间和B的一个子区间,这就已经是区间dp中的合并点,我们下一步要做的,首先要判断A、B子区间组合而成的新串是否为回文串,如果是,那么就更新答案(lena + lenb)。可以发现,更新答案是容易的,困难的是判断新串是否为回文串。所以我们可以设计出:
( b o o l ) d p [ l a ] [ r a ] [ l b ] [ r b ] (bool)dp[la][ra][lb][rb] (bool)dp[la][ra][lb][rb]
代表A[la, ra]和B[lb, rb]是否能通过某种排列得到回文串 -
状态转移
在已经找到合并点后,我们要思考,合并点的答案是由哪些部分组成的
假设枚举出A的子区间为a,B的子区间为b,a和b排列而成的新串为s,那么s的左端点和右端点只有以下四种情况:s左端点 s右端点 a左端点 a右端点 a左端点 b右端点 b左端点 b右端点 b左端点 a右端点 我们可以枚举以上四种情况,如果s左端点等于s右端点,那么问题就转化为求 s[ls + 1, rs - 1] 是否能得到回文子串,状态转移方程就出来了
if (lena > 1 && a[la] == a[ra]) dp[la][ra][lb][rb] |= dp[la + 1][ra - 1][lb][rb]; if (lenb > 1 && b[lb] == b[rb]) dp[la][ra][lb][rb] |= dp[la][ra][lb + 1][rb - 1]; if (lena != 0 && lenb != 0 && a[la] == b[rb]) dp[la][ra][lb][rb] |= dp[la + 1][ra][lb][rb - 1]; if (lena != 0 && lenb != 0 && b[lb] == a[ra]) dp[la][ra][lb][rb] |= dp[la][ra - 1][lb + 1][rb];
在完成状态转移方程后,要对dp[la][ra][lb][rb]进行一个判断,如果为真(即可以得到回文子串),那么就更新答案
a n s = m a x ( a n s , l e n a + l e n b ) ; ans=max(ans,lena+lenb); ans=max(ans,lena+lenb);
此题要注意一些边界值,加特判。
代码
#include<bits/stdc++.h>
using namespace std;
bool dp[100][100][100][100];
void solve() {
string a, b;
cin >> a >> b;
int na = a.length(), nb = b.length(), ans = 1;
a = '#' + a;
b = '#' + b;
for (int lena = 0; lena <= na; ++lena) {
for (int lenb = 0; lenb <= nb; ++lenb) {
for (int la = 1, ra = la + lena - 1; ra <= na; ++la, ++ra) {
for (int lb = 1, rb = lb + lenb - 1; rb <= nb; ++lb, ++rb) {
if (lena + lenb <= 1) {
dp[la][ra][lb][rb] = true;
} else {
dp[la][ra][lb][rb] = false;
if (lena > 1 && a[la] == a[ra]) dp[la][ra][lb][rb] |= dp[la + 1][ra - 1][lb][rb];
if (lenb > 1 && b[lb] == b[rb]) dp[la][ra][lb][rb] |= dp[la][ra][lb + 1][rb - 1];
if (lena != 0 && lenb != 0 && a[la] == b[rb]) dp[la][ra][lb][rb] |= dp[la + 1][ra][lb][rb - 1];
if (lena != 0 && lenb != 0 && b[lb] == a[ra]) dp[la][ra][lb][rb] |= dp[la][ra - 1][lb + 1][rb];
}
if (dp[la][ra][lb][rb]) {
ans = max(ans, lena + lenb);
}
}
}
}
}
cout << ans << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}