String Magic (Easy Version)
Problem Description
Z is learning string theory and he finds a difficult problem.
Given a string $S$ of length $n$ (indexed from $1$ to $n$), define $f(S)$ as the number of pairs $(i, j)$ such that:
- $1 \le i < j \le n$
- $j - i + 1 = 2k$ with $k > 0$ (that is, $j - i + 1$ is even)
- $S[i, i+k-1] = S[i+k, j]$
- $S[i, i+k-1]$ is a palindrome
Here $S[L, R]$ denotes the substring of $S$ from index $L$ to index $R$.
A palindrome is a string that reads the same from left to right as from right to left.
To solve this problem, Z needs to calculate $f(S)$.
He doesn’t know how to solve it, but he knows it’s easy for you. Please help him.
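The definition of $f(S)$ can be made concrete with a direct brute force. The following is a minimal sketch, not part of the original solution: the function name `bruteForceF` is hypothetical, indices are 0-based inside the code, and the roughly cubic running time makes it suitable only for checking small inputs.

```cpp
#include <string>

// Brute-force reference for f(S): counts pairs (i, j) straight from the
// four conditions in the statement. Only intended for tiny inputs.
long long bruteForceF(const std::string& s) {
    const int n = static_cast<int>(s.size());
    long long count = 0;
    for (int i = 0; i < n; ++i) {               // 0-based start of the substring
        for (int k = 1; i + 2 * k <= n; ++k) {  // k is the half length
            // Condition 3: the two halves of length k are equal.
            if (s.compare(i, k, s, i + k, k) != 0) continue;
            // Condition 4: the first half is a palindrome.
            bool palindrome = true;
            for (int a = i, b = i + k - 1; a < b; ++a, --b) {
                if (s[a] != s[b]) { palindrome = false; break; }
            }
            if (palindrome) ++count;
        }
    }
    return count;
}
```

Calling `bruteForceF("aaaa")` counts the pairs $(1,2)$, $(2,3)$, $(3,4)$ and $(1,4)$ and returns 4.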
Input
The first line contains one integer $T$ ($1 \le T \le 10$), the number of test cases.
For each test case: one line containing a string $S$ ($1 \le |S| \le 10^5$).
It’s guaranteed that the string only contains lowercase letters.
Output
For each test case: print one line containing one integer, the value of $f(S)$.
Sample Input
3
aaaa
abaaba
ababa
Sample Output
4
2
0
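The solution code that follows only tests whether the first half of each even-length palindrome is itself a palindrome. The source does not spell out why that suffices; one way to justify it, written here as a short derivation with $A$, $B$ and $X^{R}$ as notation introduced for this note, is that conditions 3 and 4 together hold exactly when $S[i, j]$ is a palindrome whose first half is a palindrome:

```latex
\begin{aligned}
&\text{Let } A = S[i, i+k-1],\quad B = S[i+k, j],\quad X^{R} = \text{the reverse of } X.\\
&(\Rightarrow)\quad A = B,\ A = A^{R} \;\Longrightarrow\; (AB)^{R} = B^{R}A^{R} = A^{R}A^{R} = AA = AB.\\
&(\Leftarrow)\quad AB = (AB)^{R} = B^{R}A^{R},\ |A| = |B| \;\Longrightarrow\; A = B^{R},
  \text{ and with } A = A^{R}:\ B = A^{R} = A.
\end{aligned}
```

Under this reading, $f(S)$ is the number of occurrences of even-length palindromic substrings whose first half is also a palindrome, which is a quantity a palindromic tree can accumulate per position.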
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
using i64 = long long;
using u64 = unsigned long long;
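const int N = 100010, P = 13331;  // array capacity, hash base
int tr[N][26];                    // palindromic tree transitions
int fail[N], len[N];              // suffix link and palindrome length of each node
int idx, last;                    // next free node, node of the longest palindromic suffix
u64 h[N], hr[N], p[N];            // forward hashes, backward hashes, powers of P
char s[N];                        // input string, 1-indexed
i64 res, sum[N];                  // answer; sum[v] counts qualifying palindromes on v's suffix-link chain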
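// Hash of s[l..r] read left to right (arithmetic wraps modulo 2^64).
u64 ask1(int l, int r) {
    return h[r] - h[l - 1] * p[r - l + 1];
}
// Hash of s[l..r] read right to left.
u64 ask2(int l, int r) {
    return hr[l] - hr[r + 1] * p[r - l + 1];
}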
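// Allocate a fresh node for a palindrome of length l.
int newnode(int l) {
    len[idx] = l, fail[idx] = 0;
    memset(tr[idx], 0, sizeof tr[idx]);
    return idx++;
}
// Reset the tree: node 0 is the even root (length 0), node 1 the odd root (length -1).
void init() {
    idx = last = 0;
    newnode(0), newnode(-1);
    s[0] = -1, fail[0] = 1;  // s[0] is a sentinel that matches no letter
}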
// Follow suffix links from p until the palindrome can be extended by s[i] on both sides.
int get_fail(int p, int i) {
    while (s[i - len[p] - 1] != s[i]) p = fail[p];
    return p;
}
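// For a palindrome s[l..r]: true if its length is even and its first half is a palindrome.
bool check(int l, int r) {
    int length = r - l + 1;
    if (length & 1) return false;
    int half = length / 2;
    return ask1(l, l + half - 1) == ask2(l, l + half - 1);
}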
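// Append s[i] to the tree and add the number of qualifying palindromes ending at i.
void insert(int i) {
    int u = s[i] - 'a', p = get_fail(last, i);
    if (!tr[p][u]) {
        int now = newnode(len[p] + 2);
        fail[now] = tr[get_fail(fail[p], i)][u];
        // This palindrome (if it qualifies) plus everything on its suffix-link chain.
        sum[now] = sum[fail[now]] + check(i - len[now] + 1, i);
        tr[p][u] = now;
    }
    last = tr[p][u];
    res += sum[last];
}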
void solve() {
    scanf("%s", s + 1);
    init();
    int n = strlen(s + 1);
    h[0] = hr[n + 1] = 0;
    for (int i = 1; i <= n; i++) {
        h[i] = h[i - 1] * P + s[i];
    }
    for (int i = n; i; i--) {
        hr[i] = hr[i + 1] * P + s[i];
    }
    res = 0;
    for (int i = 1; i <= n; i++) {
        insert(i);
    }
    printf("%lld\n", res);  // res is a signed 64-bit integer
}
int main() {
    p[0] = 1;
    for (int i = 1; i <= 100000; i++) {
        p[i] = p[i - 1] * P;
    }
    int T;
    scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}
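The palindrome test performed by `ask1`, `ask2` and `check` in the listing above can be looked at in isolation. The sketch below is an illustrative restatement under assumptions, not the author's code: the struct `PalindromeHash` and its member names are hypothetical, it reuses the base 13331 from the listing, and like the listing it relies on unsigned 64-bit overflow as the modulus, so equal hashes indicate a palindrome only up to hash collisions.

```cpp
#include <string>
#include <vector>

// Sketch of an O(1) palindrome test using forward and backward polynomial
// hashes. Positions passed to the query methods are 1-based and inclusive.
struct PalindromeHash {
    using u64 = unsigned long long;
    static constexpr u64 BASE = 13331;  // same base as the listing above
    int n;
    std::vector<u64> fwd, bwd, pw;      // prefix hashes, suffix hashes, powers of BASE

    explicit PalindromeHash(const std::string& s)
        : n(static_cast<int>(s.size())), fwd(n + 2, 0), bwd(n + 2, 0), pw(n + 2, 1) {
        for (int i = 1; i <= n; ++i) {
            pw[i] = pw[i - 1] * BASE;
            fwd[i] = fwd[i - 1] * BASE + static_cast<unsigned char>(s[i - 1]);
        }
        for (int i = n; i >= 1; --i) {
            bwd[i] = bwd[i + 1] * BASE + static_cast<unsigned char>(s[i - 1]);
        }
    }

    // Hash of s[l..r] read left to right.
    u64 forward(int l, int r) const { return fwd[r] - fwd[l - 1] * pw[r - l + 1]; }
    // Hash of s[l..r] read right to left.
    u64 backward(int l, int r) const { return bwd[l] - bwd[r + 1] * pw[r - l + 1]; }
    // A substring is reported as a palindrome when both readings hash equally.
    bool isPalindrome(int l, int r) const { return forward(l, r) == backward(l, r); }
};
```

For example, with `PalindromeHash ph("abaaba")`, the query `ph.isPalindrome(1, 3)` tests "aba" and `ph.isPalindrome(1, 2)` tests "ab". A single overflow-based hash can be broken by adversarial inputs, so a second modulus might be worth adding if collisions are a concern.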