题意:给定两个字符串s,t。从s中找一个子串ss,从t中找一个前缀tt,len(ss)>len(tt),求ss+tt是回文串的方案数。
分析:简单来说,将s串翻转,然后枚举t的前缀,找出s中和该前缀相等的子串,对每个这样的子串,找出该子串前面长度不为0的回文串的个数,累加一下答案就可以了。
首先把s串翻转,用回文树处理出翻转之后的s串的每个点,以该点结尾的回文串个数,然后将翻转后的s和t拼接到一起跑一遍后缀数组,由于lcp单调,从长到短枚举t的前缀,然后累加答案就可以了。
由于此题串长度为1e6,并且时限只有1s,所以需要用DC3来替代倍增的SA。
#include <bits/stdc++.h>
using namespace std;
#define F(x) ((x)/3+((x)%3==1?0:tb))
#define G(x) ((x)<tb?(x)*3+1:((x)-tb)*3+2)
const int N = 1e6 + 4;
const int MAXN = 2e7 + 100;
struct Palindromic_Tree {
int nxt[N][30], fail[N], cnt[N];
int num[N], len[N], s[N], id[N];
long long cc[N], c_num[N];
int last, n, p;
int newnode(int l) {
memset(nxt[p], 0, sizeof(nxt[p]));
cnt[p] = num[p] = 0;
len[p] = l;
return p++;
}
void init() {
p = 0;
newnode(0), newnode(-1);
last = n = 0;
s[0] = -1;
fail[0] = 1;
}
int get_fail(int x) {
while (s[n - len[x] - 1] != s[n]) x = fail[x];
return x;
}
void add(int c) {
c -= 'a';
s[++n] = c;
int cur = get_fail(last);
if (!nxt[cur][c]) {
int now = newnode(len[cur] + 2);
fail[now] = nxt[get_fail(fail[cur])][c];
nxt[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = nxt[cur][c];
cnt[last]++, id[last] = n;
cc[n] = last;
}
void Count() {
for (int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
for (int i = 1; i <= n; ++i) {
c_num[i - 1] = num[cc[i]];
}
}
} pam;
struct DC3 {
int sa[MAXN];
int rk[MAXN];
int height[MAXN];
int n, m;
int s[MAXN];
int wa[MAXN], wb[MAXN], wv[MAXN];
int wws[MAXN];
void sort(int *r, int *a, int *b, int n, int m) {
int i;
for (i = 0; i < n; i++) wv[i] = r[a[i]];
for (i = 0; i < m; i++) wws[i] = 0;
for (i = 0; i < n; i++) wws[wv[i]]++;
for (i = 1; i < m; i++) wws[i] += wws[i - 1];
for (i = n - 1; i >= 0; i--) b[--wws[wv[i]]] = a[i];
return;
}
int c0(int *r, int a, int b) { return r[a] == r[b] && r[a + 1] == r[b + 1] && r[a + 2] == r[b + 2]; }
int c12(int k, int *r, int a, int b) {
if (k == 2) return r[a] < r[b] || r[a] == r[b] && c12(1, r, a + 1, b + 1);
else return r[a] < r[b] || r[a] == r[b] && wv[a + 1] < wv[b + 1];
}
void Suffix() {
dc3(s, sa, n, m);
}
void getheight() {
calheight(s, sa, n - 1);
}
void dc3(int *r, int *sa, int n, int m) {
int i, j, *rn = r + n, *san = sa + n, ta = 0, tb = (n + 1) / 3, tbc = 0, p;
r[n] = r[n + 1] = 0;
for (i = 0; i < n; i++) if (i % 3 != 0) wa[tbc++] = i;
sort(r + 2, wa, wb, tbc, m);
sort(r + 1, wb, wa, tbc, m);
sort(r, wa, wb, tbc, m);
for (p = 1, rn[F(wb[0])] = 0, i = 1; i < tbc; i++)
rn[F(wb[i])] = c0(r, wb[i - 1], wb[i]) ? p - 1 : p++;
if (p < tbc) dc3(rn, san, tbc, p);
else for (i = 0; i < tbc; i++) san[rn[i]] = i;
for (i = 0; i < tbc; i++) if (san[i] < tb) wb[ta++] = san[i] * 3;
if (n % 3 == 1) wb[ta++] = n - 1;
sort(r, wb, wa, ta, m);
for (i = 0; i < tbc; i++) wv[wb[i] = G(san[i])] = i;
for (i = 0, j = 0, p = 0; i < ta && j < tbc; p++)
sa[p] = c12(wb[j] % 3, r, wa[i], wb[j]) ? wa[i++] : wb[j++];
for (; i < ta; p++) sa[p] = wa[i++];
for (; j < tbc; p++) sa[p] = wb[j++];
return;
}
void calheight(int *r, int *sa, int n) {
int i, j, k = 0;
for (i = 1; i <= n; ++i) rk[sa[i]] = i;
for (i = 0; i < n; height[rk[i++]] = k)
for (k ? k-- : 0, j = sa[rk[i] - 1]; r[i + k] == r[j + k]; ++k);
return;
}
} shit;
char s[N], t[N];
int main() {
shit.n = 0, shit.m = 128;
pam.init();
scanf("%s", t);
int n = strlen(t);
for (int i = 0; i < n; ++i)s[i] = t[n - i - 1];
for (int i = 0; i < n; ++i)pam.add(s[i]);
for (int i = 0; i < n; ++i)shit.s[shit.n++] = s[i];
shit.s[shit.n++] = 127;
scanf("%s", t);
int m = strlen(t);
for (int i = 0; i < m; ++i)shit.s[shit.n++] = t[i];
shit.s[shit.n++] = 0;
pam.Count();
shit.Suffix(), shit.getheight();
long long ans = 0, res = 0;
int pos = shit.rk[n + 1];
int l = pos, r = pos;
for (int len = m; len >= 1; --len) {
while (l >= 1 && shit.height[l] >= len) {
l--;
if (shit.sa[l] > 0 && shit.sa[l] < n)res += pam.c_num[shit.sa[l] - 1];
}
while (r < shit.n && shit.height[r + 1] >= len) {
r++;
if (shit.sa[r] > 0 && shit.sa[r] < n)res += pam.c_num[shit.sa[r] - 1];
}
ans += res;
}
cout << ans << endl;
}