POJ 3376 Finding Palindromes
题目大意:
给你N个字符串, 你可以两两连接得到N * N个字符串, 问之中回文串的数量. N个字符串的长度和加起来不超过2000000.
简要分析:
无比恶心的题啊...
我们顺次考虑每个字符串放在前面的情况. 假设字符串i放在前面, j放在后面, 那么这个串是回文有两种情况:
1) 若i的长度小于j, 则i是j反串的前缀, 且j反串剩下的后缀是回文串.
2) 若i的长度不小于j, 则j反串是i的前缀, 且i串剩下的后缀是回文串.
于是大致的思路就有了, 把所有串的反串丢到Trie里面, 每个结点额外记录两个值: 从这个点往下走到叶子, 有多少串是回文; 在这个点结束的字符串有多少. 这两个值就分别对应前文的两种情况了.
1) 在Trie中, 若i串在非叶子结点匹配完成, 则把该节点往下走有多少个回文串(即j反串的后缀!!!)累加到答案.
2) 在Trie中, 若在匹配i串时遇上在这个结点结束的字符串, 那么看i剩下的后缀是否是回文串, 若是, 则把在该点结束的字符串数目累加到答案.
在艰难的分析后, 问题转化成求某个串(i串和j反串)有哪些后缀是回文串. 第一反应是后缀数组, 把串和其反串连起来, 中间用奇葩字符隔开, 求一遍后缀数组, 设字符串长度为N, 则i后缀是回文串等价于i后缀与N+2后缀的LCP为N-i+1. 但是由于后缀数组巨大的空间开销和常数, 在这里用不是TLE就是MLE...于是囧了, 看来又是某种生僻算法了.
翻了下Discuss, 看到扩展KMP的字样. 百度一下发现是解决这么一个问题, 给串S和模式串T, 求S的所有后缀与T的LCP, 复杂度O(LenS + LenT). 这...不正是我们想要的吗! 令S为我们想知道哪些后缀是回文的那个串, 模式串T为其反串, 那么就看i后缀与T的LCP是否为LenS - i + 1了.
于是, 这个算法是这样的: 设下标都从0开始, S串已经处理到i后缀, ex[i]记录i后缀与T的LCP, 设i之前匹配的最远的位置是k, 则这个最远的位置p = k + ex[k] - 1. 假设我们手上还有个NX的数组next, next[i]表示T的i后缀与T的最长公共前缀, 下面开始推:
S[k..p] = T[0..p - k].
S[i..p] = T[i - k..p - k].
令next[i - k] = L, 则T[0..L - 1] = T[i - k..i - k + L - 1], S[i..i + L - 1] = T[0..L - 1].
接着我们看i + L - 1与p的关系.
1) 若i + L - 1 < p, 而p之前的位置都是枚举到过的, 所以ex[i]不会超过L, 直接ex[i] = L即可.
2) 否则, 因为p之后都是未探索的地区, 所以要往后匹配. S[i..p] = T[0..p-i], 于是令j = p - i + 1, 从S[i + j]与T[j]开始直接往后模拟匹配, 得出ex[i]值并更新k和p.
至于那个NX的next数组, 其实相当于母串S与T相同, 于是求法也就相同了, 就像KMP中的自身匹配求出pre一样.
呼~终于说完了. 只要知道了思路, 写的时候直接推导就可以了, 没必要记. 现在记不住的式子有两个了, 一个是扩展GCD, 一个是扩展KMP...
PS: 这还不是这题的恶心之处...我对最初代码大概依次有如下改动: 后缀数组改扩展KMP, TLE; string换成char数组, 用指针表示每个字符串, TLE; 去掉几个没必要的memset, TLE; 把Trie树改成邻接表用图来存, TLE...
最后我把某个int数组类型改成char, A了...
代码实现:
1 #include <cstdio> 2 #include <cstdlib> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 7 const int MAX_N = 2000000; 8 char buf[MAX_N + 1], tmp[MAX_N + 1]; 9 int n, l[MAX_N], sz[MAX_N], ex[MAX_N], next[MAX_N]; 10 char s[MAX_N], t[MAX_N]; 11 long long ans = 0LL; 12 13 namespace trie { 14 const int MAX_V = MAX_N + 1, MAX_E = 10000000; 15 int ecnt, begin[MAX_V], to[MAX_E], next[MAX_E], end[MAX_V], cnt[MAX_V]; 16 char val[MAX_V]; 17 18 int node_idx, root; 19 20 void init() { 21 ecnt = 0; 22 memset(begin, -1, sizeof(begin)); 23 node_idx = 0; 24 root = node_idx ++; 25 val[root] = -1; 26 } 27 28 void add_edge(int u, int v) { 29 next[ecnt] = begin[u]; 30 begin[u] = ecnt; 31 to[ecnt ++] = v; 32 } 33 34 void ins(int sz) { 35 int pos = root; 36 for (int i = 0; i < sz; i ++) { 37 int t = s[i] - 'a'; 38 if (ex[i] == sz - i) cnt[pos] ++; 39 bool exi = 0; 40 for (int now = begin[pos]; now != -1; now = next[now]) 41 if (val[to[now]] == t) { 42 exi = 1; 43 pos = to[now]; 44 } 45 if (!exi) { 46 int v = node_idx ++; 47 add_edge(pos, v); 48 val[v] = t; 49 pos = v; 50 } 51 } 52 end[pos] ++; 53 } 54 55 void go(int sz) { 56 int pos = root; 57 for (int i = 0; i < sz; i ++) { 58 int t = s[i] - 'a'; 59 bool exi = 0; 60 for (int now = begin[pos]; now != -1; now = next[now]) 61 if (val[to[now]] == t) { 62 exi = 1; 63 pos = to[now]; 64 } 65 if (!exi) return; 66 if (end[pos]) { 67 if (i < sz - 1 && ex[i + 1] == sz - i - 1) ans += end[pos]; 68 else if (i == sz - 1) ans += end[pos]; 69 } 70 } 71 ans += cnt[pos]; 72 } 73 } 74 75 void ex_kmp(int len) { 76 //memset(ex, 0, sizeof(int) * len), memset(next, 0, sizeof(int) * len); 77 next[0] = len; 78 next[1] = len - 1; 79 for (int i = 0; i < len - 1; i ++) 80 if (t[i] != t[i + 1]) { 81 next[1] = i; 82 break; 83 } 84 int j, k = 1, p, l; 85 for (int i = 2; i < len; i ++) { 86 p = k + next[k] - 1; 87 l = next[i - k]; 88 if (i + l - 1 < p) next[i] = l; 89 else { 90 j = max(0, p + 1 - i); 91 while (i + j < len && t[i + j] == t[j]) j ++; 92 next[i] = j, k = i; 93 } 94 } 95 ex[0] = len; 96 for (int i = 0; i < len; i ++) 97 if (s[i] != t[i]) { 98 ex[0] = i; 99 break; 100 } 101 k = 0; 102 for (int i = 1; i < len; i ++) { 103 p = k + ex[k] - 1; 104 l = next[i - k]; 105 if (i + l - 1 < p) ex[i] = l; 106 else { 107 j = max(0, p + 1 - i); 108 while (i + j < len && s[i + j] == t[j]) j ++; 109 ex[i] = j, k = i; 110 } 111 } 112 } 113 114 int main() { 115 //freopen("t.in", "r", stdin); 116 scanf("%d", &n); 117 int tot = 0; 118 for (int i = 0; i < n; i ++) { 119 scanf("%d%s", &sz[i], buf + tot); 120 l[i] = tot; 121 tot += sz[i]; 122 } 123 trie::init(); 124 for (int it = 0; it < n; it ++) { 125 for (int i = 0; i < sz[it]; i ++) s[i] = buf[l[it] + sz[it] - i - 1]; 126 for (int i = 0; i < sz[it]; i ++) t[i] = buf[l[it] + i]; 127 ex_kmp(sz[it]); 128 trie::ins(sz[it]); 129 } 130 for (int it = 0; it < n; it ++) { 131 for (int i = 0; i < sz[it]; i ++) s[i] = buf[l[it] + i]; 132 for (int i = 0; i < sz[it]; i ++) t[i] = buf[l[it] + sz[it] - i - 1]; 133 ex_kmp(sz[it]); 134 trie::go(sz[it]); 135 } 136 printf("%lld\n", ans); 137 return 0; 138 }
通过这道题,我发现了一个问题,就是,我可以很轻松的想到算法,但是,编程能力还有待提高!!!!!!!