题目链接
题意:
给出一个串串,要求选出两个子串,使他们拼接起来是一个回文串,设一种方案是[l1,r1]在前,[l2,r2]在后拼接而成,则这种方案用有序四元组(l1,r1,l2,r2)表示,求这样的四元组的个数。
题解:
首先,串串有2e5,如果他是2e5个a,那么答案大约是(2e5)^4,爆掉了ull。。。得开个 _ _ i n t _ 128 \_\_int\_128 __int_128 ,好坏怀阿。
我们分长度来讨论:
Case 1:[l1,r1]和[l2,r2]长度相等。
这种Case非常简单,我们考虑回文串的定义,他必须满足[l2,r2]反过来必须和[l1,r1]一模一样。那么我们把原串和反串插入到同一个SAM中,然后分别统计原串和反串的出现次数,然后这部分答案等于 ∑ S A M _ N o d e n u m 1 [ x ] ∗ n u m 2 [ x ] ∗ ( l e n [ x ] − l e n [ f a i l [ x ] ] ) \sum_{SAM\_Node}{num1[x] * num2[x] *(len[x] - len[fail[x]])} ∑SAM_Nodenum1[x]∗num2[x]∗(len[x]−len[fail[x]])
Case 2:[l1,r1] 长,[l2,r2]短
设[l1,r1]长度为len1,[l2,r2]长度为len2,这种Case 要求[l2,r2]反过来和[l1,r1]的前len2个字符相同,且[l1,r1]的第len2+1到len2个字符自身构成回文串。那么问题很显然了,我们要对与原串的每一个位置,他是多少个回文子串的开头位置。这个只需要Manacher处理回文半径,然后前缀和线性搞一波。
然后我们要做的是,把这个权重扔到SAM的节点上去,和SAM求每个节点串出现次数的方法一模一样。
然后这部分答案是:
∑
S
A
M
N
o
d
e
w
e
i
g
h
t
1
[
x
]
∗
n
u
m
2
[
x
]
∗
(
l
e
n
[
x
]
−
l
e
n
[
f
a
i
l
[
x
]
]
)
\sum_{SAM_Node}{weight1[x] * num2[x] *(len[x] - len[fail[x]])}
∑SAMNodeweight1[x]∗num2[x]∗(len[x]−len[fail[x]])
Case 3:[l1,r1]短,[l2,r2]长
这种情况由于我们要求的是回文串,他等价于把原串看作是反串,把反串看作是原串,做一次Case2,这部分答案的贡献是: ∑ S A M N o d e w e i g h t 2 [ x ] ∗ n u m 1 [ x ] ∗ ( l e n [ x ] − l e n [ f a i l [ x ] ] ) \sum_{SAM_Node}{weight2[x] * num1[x] *(len[x] - len[fail[x]])} ∑SAMNodeweight2[x]∗num1[x]∗(len[x]−len[fail[x]])
注意要开__int_128阿
题解是SA+单调栈+Manacher,复杂度是 O ( n l o g n ) O(nlogn) O(nlogn),我的SAM复杂度是 O ( n ∗ 26 ) O(n*26) O(n∗26)怎么都感觉应该打爆std的sa大常数阿。。然而实际上好像只快了一点点。
Code:
#include <bits/stdc++.h>
#define int ll
using namespace std;
typedef __int128_t ll;
const int maxn = 4e5+100;
char s[maxn];
char t[maxn];
long long n;
struct Suffix_Automaton{
//basic
int nxt[maxn*2][26],fa[maxn*2],l[maxn*2];
int last,cnt;
//extension
int cntA[maxn*2],A[maxn*2];/*辅助拓扑更新*/
int nums[maxn*2],numt[maxn*2];/*每个节点代表的所有串的出现次数*/
int weights[maxn*2],weightt[maxn*2];
void clear(){
last =cnt=1;
fa[1]=l[1]=0;
memset(nxt[1],0,sizeof nxt[1]);
}
void init(char *s){
while (*s){
add(*s-'a');
s++;
}
}
void add(int c){
int p = last;
int np = ++cnt;
memset(nxt[cnt],0,sizeof nxt[cnt]);
l[np] = l[p]+1;
last = np;
while (p&&!nxt[p][c])nxt[p][c] = np,p = fa[p];
if (!p)fa[np]=1;
else{
int q = nxt[p][c];
if (l[q]==l[p]+1)fa[np] =q;
else{
int nq = ++ cnt;
l[nq] = l[p]+1;
memcpy(nxt[nq],nxt[q],sizeof (nxt[q]));
fa[nq] =fa[q];
fa[np] = fa[q] =nq;
while (nxt[p][c]==q)nxt[p][c] =nq,p = fa[p];
}
}
}
void build(){
memset(cntA,0,sizeof cntA);
memset(nums,0,sizeof nums);
memset(numt,0,sizeof numt);
for (int i=1;i<=cnt;i++)cntA[l[i]]++;
for (int i=1;i<=n;i++)cntA[i]+=cntA[i-1];
for (int i=cnt;i>=1;i--)A[cntA[l[i]]--] =i;
/*更行主串节点*/
int temps=1,tempt = 1;
for (int i=1;i<=n;i++){
nums[temps = nxt[temps][s[i]-'a'] ]=1;
numt[tempt = nxt[tempt][t[i]-'a']] = 1;
}
/*拓扑更新*/
for (int i=cnt;i>=1;i--){
//basic
int x = A[i];
nums[fa[x]]+=nums[x];
numt[fa[x]] += numt[x];
}
}
void debug(){
for (int i=cnt;i>=1;i--){
printf("nums[%d]=%d numt[%d]=%d l[%d]=%d fa[%d]=%d\n",i,nums[i],i,numt[i],i,l[i],i,fa[i]);
}
}
ll query(){
ll res = 0;
for (int i=1;i<=cnt;i++){
res += 1LL*nums[i] * numt[i] * (l[i] - l[fa[i]]);
}
return res;
}
ll query2(){
ll res = 0;
for (int i=1;i<=cnt;i++){
res += 1LL*weights[i] * numt[i] *(l[i] - l[fa[i]]);
res += 1LL*weightt[i] * nums[i] *(l[i] - l[fa[i]]);
}
return res;
}
}sam;
struct Manacher{
char ch[maxn*2];
int lc[maxn*2];
int N;
void init(const char *s){
N = 2*n+1;
ch[N] = '#';
for (int i=n;i>=1;i--){
ch[i*2] = s[i];
ch[i*2-1] = '#';
}
ch[0] = 'z'+1;
ch[N+1] = '\0';
manacher();
}
void manacher(){
lc[1] = 1;
int k = 1;
for (int i=2;i<=N;i++){
int p = k + lc[k] -1;
if (i <= p){
lc[i] = min(lc[2*k-i],p-i+1);
}else{
lc[i] = 1;
}
while (ch[i+lc[i]] == ch[i-lc[i]])lc[i] ++;
if (i+lc[i] > k+lc[k]) k = i;
}
}
int query(int x){
return lc[x]>>1;
}
void debug(){
for (int i=1;i<=N;i++){
printf("lc[%d]=%d\n",i,lc[i]);
}
}
}mas,mat;
struct Prefix_Sum{
ll val[maxn*2];
void clear(){
memset(val,0,sizeof val);
}
void add(int l,int r,int delta){
val[l] += delta;
val[r+1] -= delta;
}
void build(){
//val[0] = 0;
for (int i=1;i<maxn;i++){
val[i] += val[i-1];
}
}
ll query(int x){
return val[x];
}
void debug(){
for (int i=1;i<=n;i++){
printf("val[%d]=%d\n",i,val[i]);
}
}
}sums,sumt;
ll calc(){
mas.init(s);
mat.init(t);
//mas.debug();
//mat.debug();
sums.clear();
sumt.clear();
for (int i=1;i<=n;i++){
sums.add(i-mas.query(i<<1),i-1,1);
sumt.add(i-mat.query(i<<1),i-1,1);
sums.add(i-mas.query(i<<1|1),i-1,1);
sumt.add(i-mat.query(i<<1|1),i-1,1);
//printf("lcs[%d]=%d lct[%d]=%d\n",i,mas.query(i),i,mat.query(i));
}
sums.build();
sumt.build();
//sums.debug();
//sumt.debug();
int temps = 1;
int tempt = 1;
for (int i=1;i<=n;i++){
temps = sam.nxt[temps][s[i] - 'a'];
tempt = sam.nxt[tempt][t[i] - 'a'];
sam.weights[temps] += sums.query(i);
sam.weightt[tempt] += sumt.query(i);
}
for (int i=sam.cnt;i>=1;i--){
//basic
int x = sam.A[i];
sam.weights[sam.fa[x]]+=sam.weights[x];
sam.weightt[sam.fa[x]]+=sam.weightt[x];
}
return sam.query2();
}
template <class T>
void print(T a)
{
if (a>9) print(a/10);
putchar(a%10+'0');
}
signed main(){
scanf("%lld",&n);
scanf("%s",s+1);
memcpy(t,s,sizeof t);
reverse(t+1,t+n+1);
//cout<<s+1<<" "<<t+1<<endl;
sam.clear();
sam.init(s+1);
sam.last = 1;
sam.init(t+1);
sam.build();
ll ans = sam.query();
//sam.debug();
//cout<<ans<<endl;
ans += calc();
print(ans);
return 0;
}