题目大意:在一个字符串序列中,求有多少个不连续的回文子串 (整体不能全部连续,这个子串的部分可以连续).
初次学FFT对这种题是完全懵逼的,感谢大佬的题解:
https://blog.csdn.net/popoqqq/article/details/42193259
答案可以用所有的回文串扣掉连续的回文串,连续的回文串可以用马拉车O(N)求出。
所有回文串:设 F[i] 为 以第i个字母为对称中心时两端对称相等的字符对数,则以F[i]为中心的所有回文子串为:
2
F
[
i
]
2^{F[i]}
2F[i] - 1。
如何求F[i]:分别求字母’a’和字母’b’的贡献。若当前要求字母c的贡献,对串进行处理:若第i个位置为c,则第i个位置的权值为1,否则为0,则F[i] = ∑ j = 0 2 i − 1 v a l [ j ] ∗ v a l [ 2 i − 1 − j ] 2 \sum_{j=0}^{2i - 1}\frac{val[j]*val[2i - 1- j]}{2} ∑j=02i−12val[j]∗val[2i−1−j],这正是原串对原串的卷积。
用FFT算法可以快速求出所有的F[I]。如果对称轴不是字符而是对称线,卷积也能处理到这种情况。因为卷积是所有前缀的交叉相乘,前缀长度为奇数时对称轴是字符,为偶数时对称轴是线,不会遗漏答案。
代码
#include<iostream>
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<math.h>
using namespace std;
const int maxn = 4e5 + 10;
const double pi = acos(-1.0);
const int mod = 1e9 + 7;
char x[maxn],tmp[maxn],y[maxn];
int top = 0,sz = 0;
int len[maxn];
struct complex{
double r,i;
complex(double _r = 0.0,double _i = 0.0) {
r = _r;
i = _i;
}
complex operator + (const complex & rhs) {
return complex(rhs.r + r,rhs.i + i);
}
complex operator - (const complex & rhs) {
return complex(r - rhs.r,i - rhs.i);
}
complex operator * (const complex & rhs) {
return complex(r * rhs.r - i * rhs.i,r * rhs.i + i * rhs.r);
}
};
complex A[maxn],B[maxn];
int ans[maxn],sum[maxn];
void change(complex a[],int len) {
int tot = 0;
while((1 << tot) < len) tot++;
tot--;
for(int i = 0; i < len; i++) {
int cur = 0;
for(int j = 0; j <= tot; j++)
if(i & (1 << j))
cur |= 1 << (tot - j);
if(i < cur) swap(a[i],a[cur]);
}
}
void fft(complex a[],int len,int type) {
change(a,len);
for(int i = 2; i <= len; i <<= 1) {
complex wp = complex(cos(2 * pi * type / i),sin(2 * pi * type / i));
for(int j = 0; j < len; j += i) {
complex w = complex(1,0);
for(int k = 0; k < i / 2; k++) {
complex t = a[j + k];
complex u = w * a[j + k + i / 2];
a[j + k] = t + u;
a[j + k + i / 2] = t - u;
w = w * wp;
}
}
}
if(type == -1) {
for(int i = 0; i < len; i++)
a[i].r /= len;
}
}
int solve(char c) {
int len = 1;
while(len < 2 * top) len <<= 1;
for(int i = 0; i < top; i++)
A[i] = complex(x[i] == c,0);
for(int i = top; i < len; i++)
A[i] = complex(0,0);
fft(A,len,1);
for(int i = 0; i < len; i++)
A[i] = A[i] * A[i];
fft(A,len,-1);
for(int i = 0; i < len; i++) {
ans[i] += (int) (A[i].r + 0.5);
}
return len;
}
long long manacher() {
int id = 0,mx = -1;
long long ans = 0;
tmp[sz++] = '$';
for(int i = 0; i < top; i++) {
tmp[sz++] = '#';
tmp[sz++] = x[i];
}
tmp[sz++] = '#';
len[0] = 1;
for(int i = 0; i < sz; i++) {
if(i < mx) len[i] = min(len[2 * id - i],mx - i);
else len[i] = 1;
while(tmp[i + len[i]] == tmp[i - len[i]]) len[i]++;
if(i + len[i] > mx) {
mx = i + len[i];
id = i;
}
ans += len[i] / 2;
ans %= mod;
}
return ans;
}
long long fpow(long long a,long long b) {
long long r = 1;
while(b) {
if(b & 1) r *= a,r %= mod;
a *= a;
a %= mod;
b >>= 1;
}
return r;
}
int main() {
scanf("%s",x);
for(int i = 0; x[i]; i++) {
if(x[i] == 'b' || x[i] == 'a')
x[top++] = x[i];
}
x[top] = 0;
long long t = manacher();
long long p = 0;
int z = solve('a');
solve('b');
for(int i = 0; i < z; i++) {
p += fpow(2,ans[i] + 1 >> 1) - 1;
p %= mod;
}
printf("%lld\n",(p + mod - t) % mod);
return 0;
}