题目背景
为了研究 OID 打牌,OID 出了这么一道题
题面
我们按以下的方式定义一个字符串是否是好牌:
- 空串是好牌;
- 如果 S S S 是好牌,那么 a S a aSa aSa、 b S b bSb bSb、 c S c cSc cSc 是好牌;
- 如果 S S S 和 S S S 都是好牌,那么 S T ST ST 是好牌;
- 不能被以上方式定义为好牌的字符串都不是好牌。
也就是说,一个串是好牌,当且仅当它是一个只包含 a
、 b
、 c
的合法引号序列。
给定一个只包含 a
、 b
、 c
的字符串 ,求有多少种方案交换两个不同字符,使得交换之后该串是好牌。
∣ S ∣ ≤ 1 0 5 |S|\leq 10^5 ∣S∣≤105 .
题解
为了将 OID 的牌术研究透彻,OID 命令题解要以 OID 的文风来写
好牌等价于每次选任意两个相邻相同的字符删掉,可以把整个串消完。于是我们就可以模拟这个过程判断好牌。两个串拼起来是好牌,当且仅当分别把它们相邻相同的消尽后,两个串对称相等。
同样地,可以模拟将两个串消尽后的部分拼起来,在分界线处不断删除相同字符,最终可以得到合并串的消除结果。
于是,就有一个 O ( n 2 ) O(n^2) O(n2) 的做法,把前缀后缀的消串结果用 trie 树存下来,然后枚举交换的两点,右端点从左往右扫的过程中将右端点前的字符串的消串结果处理出来。用哈希判断对称相等。
考后我想了想,我为什么没能想到分治正解。我试图将题面改了下,变成在一棵树上,求路径变成好牌,虽然不会做,但是果然很容易就想到点分治……
这道题用分治怎么做呢?就是将每一个交换的方案放到分治区间里,左端点在左半边统计,右端点在右半边统计,递归下去。
我们可以通过前缀的信息 + 暴力扫描 + 二分法合并,均摊 O ( log n ) O(\log n) O(logn) 得到左半边串在交换后的所有 O ( n ) O(n) O(n) 个串的哈希值,右半边同理。于是通过哈希表统计答案。
CODE
比较卡哈希,一定要同时以哈希值和串长为关键字。
#include<map>
#include<set>
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<random>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<unordered_map>
// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast")
using namespace std;
#define MAXN 100005
#define LL long long
#define ULL unsigned long long
#define ENDL putchar('\n')
#define DB double
#define lowbit(x) (-(x) & (x))
#define FI first
#define SE second
int xchar() {
static const int maxn = 1000000;
static char b[maxn];
static int pos = 0,len = 0;
if(pos == len) pos = 0,len = fread(b,1,maxn,stdin);
if(pos == len) return -1;
return b[pos ++];
}
// #define getchar() xchar()
LL read() {
LL f = 1,x = 0;int s = getchar();
while(s < '0' || s > '9') {if(s<0)return -1;if(s=='-')f=-f;s = getchar();}
while(s >= '0' && s <= '9') {x = (x<<1) + (x<<3) + (s^48);s = getchar();}
return f*x;
}
void putpos(LL x) {if(!x)return ;putpos(x/10);putchar((x%10)^48);}
void putnum(LL x) {
if(!x) {putchar('0');return ;}
if(x<0) putchar('-'),x = -x;
return putpos(x);
}
void AIput(LL x,int c) {putnum(x);putchar(c);}
const int MOD = 993244853;
int n,m,s,o,k;
const int B = 37;
char ss[MAXN];
int sf[MAXN],pr[MAXN],pw[MAXN];
int t1[MAXN],t2[MAXN];
int f1[MAXN][20],f2[MAXN][20],d1[MAXN],d2[MAXN];
int h1[MAXN][20],h2[MAXN][20];
int hs[MAXN],sta[MAXN],tp,hs2[MAXN];
LL ans = 0;
void solve(int l,int r) {
if(l >= r) return ;
int md = (l + r) >> 1;
map<LL,int> mp[3][3];
sta[tp = 0] = -1; hs[0] = 0;
for(int i = md;i >= l;i --) {
int me = ss[i]-'a';
for(int c = 0;c < 3;c ++) {
if(c == me) continue;
int st = tp+1; sta[st] = c;
hs[st] = (hs[st-1]*1ll*B + c+1) % MOD;
hs2[st] = (pw[st-1]*1ll*(c+1) + hs2[st-1]) % MOD;
if(c == sta[st-1]) st -= 2;
int p = t1[i-1];
for(int j = 16;j >= 0;j --) {
if(st >= (1<<j) && h1[p][j] == (hs[st]+MOD-hs[st-(1<<j)]*1ll*pw[1<<j]%MOD) % MOD) {
st -= (1<<j); p = f1[p][j];
}
}
int has = (pr[p]*1ll*pw[st] + hs2[st]) % MOD;
int le = d1[p] + st;
mp[me][c][has*1000000ll + le] ++;
}
if(me == sta[tp]) tp --;
else {
sta[++ tp] = me;
hs[tp] = (hs[tp-1]*1ll*B + me+1) % MOD;
hs2[tp] = (pw[tp-1]*1ll*(me+1) + hs2[tp-1]) % MOD;
}
}
sta[tp = 0] = -1; hs[0] = 0;
for(int i = md+1;i <= r;i ++) {
int me = ss[i]-'a';
for(int c = 0;c < 3;c ++) {
if(c == me) continue;
int st = tp+1; sta[st] = c;
hs[st] = (hs[st-1]*1ll*B + c+1) % MOD;
hs2[st] = (pw[st-1]*1ll*(c+1) + hs2[st-1]) % MOD;
if(c == sta[st-1]) st -= 2;
int p = t2[i+1];
for(int j = 16;j >= 0;j --) {
if(st >= (1<<j) && h2[p][j] == (hs[st]+MOD-hs[st-(1<<j)]*1ll*pw[1<<j]%MOD) % MOD) {
st -= (1<<j); p = f2[p][j];
}
}
int has = (sf[p]*1ll*pw[st] + hs2[st]) % MOD;
int le = d2[p] + st;
ans += mp[c][me][has*1000000ll + le];
}
if(me == sta[tp]) tp --;
else {
sta[++ tp] = me;
hs[tp] = (hs[tp-1]*1ll*B + me+1) % MOD;
hs2[tp] = (pw[tp-1]*1ll*(me+1) + hs2[tp-1]) % MOD;
}
}
solve(l,md); solve(md+1,r);
return ;
}
int main() {
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
scanf("%s",ss + 1);
n = strlen(ss + 1);
pw[0] = 1;
for(int i = 1;i <= n;i ++) pw[i] = pw[i-1] *1ll* B % MOD;
sf[n+1] = 0; pr[0] = 0;
ss[0] = ss[n+1] = 0;
for(int i = 1;i <= n;i ++) {
if(ss[i] == ss[t1[i-1]]) t1[i] = f1[t1[i-1]][0];
else t1[i] = i,f1[i][0] = t1[i-1];
}
for(int i = 1;i <= n;i ++) {
if(t1[i] == i) {
h1[i][0] = ss[i]-'a'+1;
int pw = B;
for(int j = 1;j <= 17;j ++) {
f1[i][j] = f1[f1[i][j-1]][j-1];
h1[i][j] = (h1[f1[i][j-1]][j-1] *1ll* pw + h1[i][j-1]) % MOD;
pw = pw *1ll* pw % MOD;
}
pr[i] = (pr[f1[i][0]]*1ll*B + (ss[i]-'a'+1)) % MOD;
d1[i] = d1[f1[i][0]] + 1;
}
}
for(int i = n;i > 0;i --) {
if(ss[i] == ss[t2[i+1]]) t2[i] = f2[t2[i+1]][0];
else t2[i] = i,f2[i][0] = t2[i+1];
}
for(int i = n;i > 0;i --) {
if(t2[i] == i) {
h2[i][0] = ss[i]-'a'+1;
int pw = B;
for(int j = 1;j <= 17;j ++) {
f2[i][j] = f2[f2[i][j-1]][j-1];
h2[i][j] = (h2[f2[i][j-1]][j-1] *1ll* pw + h2[i][j-1]) % MOD;
pw = pw *1ll* pw % MOD;
}
sf[i] = (sf[f2[i][0]]*1ll*B + (ss[i]-'a'+1)) % MOD;
d2[i] = d2[f2[i][0]] + 1;
}
}
solve(1,n);
AIput(ans,'\n');
return 0;
}