题意
定义一个字符串合法当且仅当:
- 它为空串;
- 它形如 a S a \texttt aS\texttt a aSa、 b S b \texttt bS\texttt b bSb 或 c S c \texttt cS\texttt c cSc,其中 S S S 是合法的;
- 它形如 S T ST ST,其中 S S S、 T T T 都是合法的。
给出字符串 S S S,问:有多少种交换两不同字符的方案,使得交换后的 S S S 合法。
∣ S ∣ ≤ 1 0 5 |S|\leq 10^5 ∣S∣≤105,3s。
题解
假如一个字符串每次删除相邻的两个相同字符,最后变成空串,那它显然合法。我们定义 f ( S ) f(S) f(S) 表示 S S S 每次删除相邻的两个相同字符直到不能再删为止,定义 S ′ S' S′ 为 S S S 的反串。 f ( S ) f(S) f(S) 能够做到 O ( 1 ) O(1) O(1) 维护在末尾增添字符、 O ( log ) O(\log) O(log) 合并两个串(二分+哈希)。
考虑分治:找到 m i d mid mid;把 [ l , m i d ] [l,mid] [l,mid] 中的某个字符改为另一个,求出更改后的 f ( S [ l , m i d ] ) f(S[l,mid]) f(S[l,mid]);把 [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 中的某个字符改为另一个,看有多少更改的字符恰好对应,且 f ( S [ l , m i d ] ) = f ( S [ m i d + 1 , r ] ) ′ f(S[l,mid])=f(S[mid+1,r])' f(S[l,mid])=f(S[mid+1,r])′ 的在左侧更改的方式;最后递归左半边、右半边。
时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define sz size()
const int mod1 =1e9+7,mod2 =1e9+9,
base1=5 ,base2=13 ;
const int N=1e5+10;
pii base(base1,base2);
inline pii operator+(pii a,pii b){
pii c=mp(a.fi+b.fi,a.se+b.se);
while(c.fi>=mod1)c.fi-=mod1;
while(c.se>=mod2)c.se-=mod2;
return c;
}
inline pii operator+(pii a,int b){ return a+mp(b,b); }
inline pii operator*(pii a,pii b){ return mp(a.fi*1ll*b.fi%mod1,a.se*1ll*b.se%mod2); }
inline pii operator*(pii a,int b){ return a*mp(b,b); }
inline pii operator-(pii a,pii b){ return a+mp(mod1-b.fi,mod2-b.se); }
char s[N];
pii f[N];
int n;
struct str{
int n;
vector<char>x;
vector<pii >a,b;
inline int size() const { return x.sz; }
inline void push(char c){
int m=x.sz;
x.pb(c);
a.pb(a.back()*base+(c-'a'+1));
b.pb(b.back()+f[m-1]*(c-'a'+1));
}
inline void del(){
x.pop_back();a.pop_back();b.pop_back();
}
inline void add(char c){
if(x.back()==c)del();
else push(c);
}
inline void movel(int m){
m=max(m,0);
while(n>m)add(s[n--]);
while(n<m)add(s[++n]);
}
inline void mover(int m){
m=min(m,::n+1);
while(n>m)add(s[--n]);
while(n<m)add(s[n++]);
}
inline void init(int m){
n=m;
x.clear(),a.clear(),b.clear();
x.pb(0);
a.pb(mp(0,0));b.pb(mp(0,0));
}
inline pii get(int k){
return a[sz-1]-a[sz-k-1]*f[k];
}
}a,b,c;
ostream& operator<<(ostream &out,const str &s){
for(int i=1;i<s.sz;i++)out<<s.x[i];
return out;
}
inline int calc(int i,int j){
return i*2+j-(j>i);
}
inline pii merge(str &l,char c,str &r){
int k=0;
l.add(c);
for(int i=16;i>=0;--i)
if(k+(1<<i)<=min(l.sz,r.sz)-1&&l.get(k+(1<<i))==r.get(k+(1<<i)))
k+=1<<i;
pii ans=l.a[l.sz-k-1]*f[r.sz-k-1]+r.b[r.sz-k-1];
l.add(c);
//cerr<<"merge "<<l<<"+"<<c<<"+"<<r<<" "<<ans.fi<<" "<<ans.se<<endl;
return ans;
}
map<pii,int>w[6];
ll ans=0;
void solve(int l,int r){
if(l==r)return;
//cerr<<"solve "<<l<<" "<<r<<endl;
int mid=l+r>>1;
a.movel(mid-1);
c.init(mid+1);
for(int i=mid;i>=l;i--,a.movel(i-1),c.mover(i+1))
for(int j=0;j<=2;j++)
if(s[i]!='a'+j)
w[calc(s[i]-'a',j)][merge(a,j+'a',c)]++;
b.mover(mid+2);
c.init(mid);
for(int i=mid+1;i<=r;i++,b.mover(i+1),c.movel(i-1))
for(int j=0;j<=2;j++)
if(s[i]!='a'+j)
ans+=w[calc(j,s[i]-'a')][merge(b,j+'a',c)];
for(int i=0;i<6;i++)w[i].clear();
solve(l,mid);
solve(mid+1,r);
}
int main(){
scanf("%s",s+1);
n=strlen(s+1);
a.init(0);
b.init(n+1);
f[0]=mp(1,1);
for(int i=1;i<=n;i++)f[i]=f[i-1]*base;
solve(1,n);
printf("%lld\n",ans);
return 0;
}