链接的G题: http://codeforces.com/gym/100548
1. 由于不会回文树, 看到题目感觉很SAM
2. 仔细一想可以发现 SAM 中一个节点所代表的字符串最多只有一个是回文串
以同一个字母结尾的不同回文串 出现的位置不可能完全相同
说明一个长为n的串种出现的回文子串种类不超过n种
3.在新建一个结点的时候如何判断——该节点所存储的字符串中是否有回文串呢?
如果有的话,必然是以当前字符结尾的最长的那个回文串
Manacher处理一下 最长的有多长。 然后比较一下是否大于等于minlen[cur] 就行了
4.在匹配到SAM上一个节点时具体要加多少呢?
假设当前匹配到的字符串长度为len
那要加的就是比len短的回文串在模式串中出现的次数
具体就是suffix_link上去找, 记得乘上其出现的次数
5. 顺便回头看了一下, 好像回文树能做的, SAM + manacher 都能做?
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <algorithm>
#include <stack>
#include <cctype>
#include <cmath>
#include <vector>
#include <sstream>
#include <bitset>
#include <deque>
#include <iomanip>
using namespace std;
#define pr(x) cout << #x << " = " << x << endl;
#define bug cout << "bugbug" << endl;
#define ppr(x, y) printf("(%d, %d)\n", x, y);
#define MST(a,b) memset(a,b,sizeof(a))
#define CLR(a) MST(a,0)
#define SQR(a) ((a)*(a))
#define PCUT puts("\n---------------")
typedef long long ll;
typedef double DBL;
typedef pair<int, int> P;
typedef unsigned int uint;
const int MOD = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const int maxn = 4e5 + 4;
const int maxm = 1e3 + 4;
const double pi = acos(-1.0);
int m[maxn], M[maxn], Temp[maxn];
char buf[maxn];
void manacher(char* s){
int j = 1;
buf[0] = '#';
for (int i = 0; s[i]; ++i){
buf[j++] = '$';
buf[j++] = s[i];
}
buf[j++] = '$';
buf[j++] = 0;
int center = -1, maxv = -1;
for (int i = 1; buf[i]; ++i){
m[i] = i <= maxv ? min(m[2*center-i], maxv - i + 1) : 1;
while(buf[i-m[i]] == buf[i+m[i]]) m[i]++;
if (i + m[i] - 1 > maxv){
maxv = i + m[i] - 1;
center = i;
}
}
int lst = 0;
for (int i = 1; buf[i]; ++i)
if (i + m[i] - 1 > lst) while (i + m[i] - 1 > lst){
lst++;
Temp[lst] = (lst - i) * 2 + 1;
}
for (int i = 2; buf[i]; i += 2){
M[i/2-1] = (Temp[i] + 1) / 2;
}
return;
}
stack<int> sta;
struct SuffixAutoMachine{
int trans[maxn][26], maxlen[maxn], link[maxn], cnt, lst;//写成了我手上AC自动机模板那个风格
int times[maxn], in[maxn];
ll plus[maxn], len_p[maxn];//当前节点的回文串长度
inline int id(char x){
return x - 'a';
}
void newNode(){
cnt++;
memset(trans[cnt], -1, sizeof trans[cnt]);
link[cnt] = len_p[cnt] = -1;
times[cnt] = plus[cnt] = 0;
return;
}
void insert(int id){
newNode();
maxlen[cnt] = maxlen[lst] + 1;
times[cnt] = 1;
int u = lst;
while(u != -1 && trans[u][id] == -1){
trans[u][id] = cnt;
u = link[u];
}
if (u == -1){
link[cnt] = 0;
lst = cnt;
return;
}
int q = trans[u][id];
if (maxlen[q] == maxlen[u] + 1){
link[cnt] = q;
lst = cnt;
return;
}
int cur = cnt, sq = cnt+1;
newNode();
maxlen[sq] = maxlen[u] + 1;
memcpy(trans[sq], trans[q], sizeof trans[q]);
link[sq] = link[q];
link[q] = sq;
link[cur] = sq;
lst = cur;
if (len_p[q] != -1 && len_p[q] <= maxlen[sq]){
len_p[sq] = len_p[q];
len_p[q] = -1;
}
while(u != -1 && trans[u][id] == q){
trans[u][id] = sq;
u = link[u];
}
return;
}
void init(){
cnt = -1;
newNode();
maxlen[0] = 0;
lst = 0;
}
void construct(char* s){
init();
for (int i = 0; s[i]; ++i){
insert(id(s[i]));
int bef = link[lst];
// cout << s[i] << ' ' << M[i] << ' ' << maxlen[bef] << endl;
if (M[i] > maxlen[bef]) len_p[lst] = M[i];
}
while(sta.size()) sta.pop();
memset(in, 0, sizeof in);
for (int i = 1; i <= cnt; ++i) in[link[i]]++;
queue<int> q;
for (int i = 1; i <= cnt; ++i) if (!in[i]) q.push(i);
while(q.size()){
int top = q.front(); q.pop(); sta.push(top);
if (top == 0) continue;
times[link[top]] += times[top];
if (--in[link[top]] == 0) q.push(link[top]);
}
while(sta.size()){
int top = sta.top(); sta.pop();
if (top == 0) continue;
plus[top] = plus[link[top]];
if (len_p[top] != -1) plus[top] += times[top];
}
}
ll match(char *s){
ll sum = 0;
int cur = 0, len = 0;
for (int i = 0; s[i]; ++i){
int j = id(s[i]);
while(cur && trans[cur][j] == -1) cur = link[cur], len = maxlen[cur];
if (trans[cur][j] != -1){
cur = trans[cur][j];
// cout << i << ' ' << maxlen[cur] << ' ' << plus[cur] << endl;
len++;
// cout << len_p[cur] << endl;
if (len_p[cur] != -1 && len_p[cur] <= len) sum += plus[cur];
else sum += plus[link[cur]];
}
}
return sum;
}
}SAM;
char a[maxn], b[maxn];
int main(){
//必须编译过才能交
int ik, i, j, k, kase;
//SAM要求有两倍字符串长度的节点。
//匹配时 --i 和 cur == 0两个需要注意的地方
scanf("%d", &kase);
for (ik = 1; ik <= kase; ++ik){
scanf("%s%s", a, b);
manacher(a);
SAM.construct(a);
printf("Case #%d: %I64d\n", ik, SAM.match(b));
}
return 0;
}
/*
10
ewewewwe
eeewwwwq
ewewewwe
eee
*/