回文子串
题目背景:
11.02 NOIP模拟T3
分析:DP 记忆化搜索
一开始看这个题,认认真真的想了很久最后发现:我还是暴力吧···后来才知道这玩意儿是DP,我们来定义f[sl][tl][sr][tr]表示,最终的串的前(sl + tl)位已经用s串的前sl位,和t串的前tl位填上了,最后的(lens - sr + lent - tr)位被s串的sr ~ lens位和t串的tr ~ lent位填上了的最优的长度是多少,那么考虑如何更新这个状态。
f[sl - 1][tl][sr + 1][tr] + 2 à f[sl][tl][sr][tr] (sl < sr&& s[sl] == s[sr])
f[sl][tl - 1][sr][tr + 1] + 2 à f[sl][tl][sr][tr] (tl < tr&& t[tl] == t[tr])
f[sl - 1][tl - (tl == tr)][sr + (sl== sr)][tr + 1] + 2 à f[sl][tl][sr][tr] (sl <= sr&& tl <= tr && s[sl] == t[tr])
f[sl - (sl == sr)][tl - 1][sr + 1][tr+ (tl == tr)] + 2 à f[sl][tl][sr][tr] (sl <= sr&& tl <= tr && t[tl] == s[sr])
边界就是f[0][0][lens][lent] = 0这个DP的状态用for循环来写比较麻烦,所以我是直接选择了记忆化搜索,比for循环快了很多,应该是无用状态比较少的原因了。注意到我们目前只考虑了偶数回文串的情况,考虑奇数回文串的状态就是,有一个串已经全部用完,另一个串还剩一个字符的时候,在中间插上最后一个字符,而偶数回文串的情况就是直接将两个字符串用完的情况了,复杂度O(n4)
Source:
/*
created by scarlyw
*/
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <string>
#include <cstring>
#include <cctype>
#include <vector>
#include <queue>
#include <set>
#include <ctime>
const int MAXN = 50 + 10;
char s[MAXN], t[MAXN];
int lens, lent, ans;
int f[MAXN][MAXN][MAXN][MAXN];
inline int dfs(int sl, int tl, int sr, int tr) {
if (sl == 0 && tl == 0 && sr == lens + 1 && tr == lent + 1) return 0;
if (~f[sl][tl][sr][tr]) return f[sl][tl][sr][tr];
int ret = 0;
if (sl != 0 && sr != lens + 1 && sl < sr && s[sl] == s[sr])
ret = std::max(ret, dfs(sl - 1, tl, sr + 1, tr) + 2);
if (tl != 0 && tr != lent + 1 && tl < tr && t[tl] == t[tr])
ret = std::max(ret, dfs(sl, tl - 1, sr, tr + 1) + 2);
if (sl != 0 && sl != lens + 1 && tr != 0 &&
tr != lent + 1 && sl <= sr && tl <= tr && s[sl] == t[tr])
ret = std::max(ret, dfs(sl - 1, tl - (tl == tr),
sr - (sl == sr), tr + 1) + 2);
if (sr != 0 && sr != lens + 1 && tl != 0 &&
tl != lent + 1 && sl <= sr && tl <= tr && t[tl] == s[sr])
ret = std::max(ret, dfs(sl - (sl == sr), tl - 1,
sr + 1, tr + (tl == tr)) + 2);
return f[sl][tl][sr][tr] = ret;
}
inline void solve() {
scanf("%s", s + 1), scanf("%s", t + 1);
lens = strlen(s + 1), lent = strlen(t + 1), ans = 0;
memset(f, -1, sizeof(f));
for (int i = 1; i <= lens; ++i)
for (int j = 1; j <= lent; ++j) {
ans = std::max(ans, dfs(i, j, i + 1, j + 1));
ans = std::max(ans, dfs(i, j, i + 2, j + 1) + 1);
ans = std::max(ans, dfs(i, j, i + 1, j + 2) + 1);
}
std::cout << ans;
}
int main() {
solve();
return 0;
}