题意
定义在确定了模式串
s
t
r
str
str 的情况下,字符串
S
S
S 的权值为
s
t
r
str
str 在
S
S
S 中出现的次数。现给定字符串
S
S
S,
Q
Q
Q 次询问:指定
S
S
S 的某个子串为模式串,求
S
S
S 的所有本质不同子串的权值和。
∣
S
∣
,
Q
≤
4
×
1
0
5
|S|,Q\leq 4\times 10^5
∣S∣,Q≤4×105。时限 3s 1s。
题解
(官方题解被某神仙 D 了,说是官方题解数据结构学傻了)
建 S S S 的 SAM。记询问串为 T T T。
首先换个思路:如果 T T T 后面接上一个字符串成为 Q Q Q, Q Q Q 在前面接上一个字符串成为 R R R,且 R R R 出现在了 S S S 中,则 R R R 会对权值产生 1 的贡献。
先考虑已经确定了 Q Q Q,如何数 R R R,即统计后缀为 Q Q Q 的字符串个数。 Q Q Q 对应的节点(记为 q q q)在 SAM 的 fail 树上的子树中,不是 q q q 的所有节点都会有贡献,而 q q q 上的字符串,长度不少于 ∣ Q ∣ |Q| ∣Q∣ 的也会有贡献,所以 q q q 对权值的贡献可以表示为 − ∣ Q ∣ + b -|Q|+b −∣Q∣+b,其中 b b b 是一个常数。
给定 T T T 后,在 T T T 后面加字符就相当于在后缀自动机的 DAG 上走,所以在 DAG 上再做一次 DP,以统计 t t t 对权值的贡献(可以表示为 k ∣ T ∣ + b k|T|+b k∣T∣+b, k , b k,b k,b 都是常数)。转移时注意 k ∣ T ∣ + b k|T|+b k∣T∣+b 会变成 k ( ∣ T ∣ + 1 ) + b = k ∣ T ∣ + ( k + b ) k(|T|+1)+b=k|T|+(k+b) k(∣T∣+1)+b=k∣T∣+(k+b)。每次询问找到 T T T 对应的节点,把 T T T 代进去算即可。
(然而我直到做这道题之前还不知道如何找到 T ( S [ l … r ] ) T(S[l\dots r]) T(S[l…r]) 的对应节点:先找到 S [ 1 … r ] S[1\dots r] S[1…r] 的对应节点(可以在构建 SAM 时顺便记录),然后倍增/树剖跳 f a i l fail fail 直到最后一个 l e n ≥ ∣ T ∣ len\geq |T| len≥∣T∣ 的节点)
时间复杂度 O ( ∣ S ∣ σ + Q log ∣ S ∣ ) O(|S|\sigma +Q\log |S|) O(∣S∣σ+Qlog∣S∣)。
代码:
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans*f;
}
const int N=4e5+10,S=26;
struct node{
int ch[S];
int fail,len;
};
node sam[N<<1];
int cnt=0,lst=0;
void init(){
memset(sam,-1,sizeof(sam));
sam[cnt++].len=0;
}
int pos[N];
void extend(char c,int p){
int cur=cnt++;
sam[cur].len=sam[lst].len+1;
for(;(~lst)&&(!~sam[lst].ch[c-'a']);lst=sam[lst].fail)
if(!~sam[lst].ch[c-'a'])sam[lst].ch[c-'a']=cur;
if(!~lst)sam[cur].fail=0;
else{
int q=sam[lst].ch[c-'a'];
if(sam[q].len==sam[lst].len+1)sam[cur].fail=q;
else{
int clone=cnt++;
memcpy(sam+clone,sam+q,sizeof(node));
sam[clone].len=sam[lst].len+1;
sam[q].fail=sam[cur].fail=clone;
for(;sam[lst].ch[c-'a']==q;lst=sam[lst].fail)
sam[lst].ch[c-'a']=clone;
}
}
lst=cur;
pos[p]=lst;
}
void print(){
for(int i=0;i<cnt;i++){
cerr<<"["<<i<<"] "<<sam[i].len<<endl;
for(int j=0;j<S;j++){
if(~sam[i].ch[j])cerr<<i<<" "<<sam[i].ch[j]<<" "<<char(j+'a')<<endl;
}
}
}
char str[N];
struct bian{
int e,n;
};
bian b[N<<2];
int s[N<<1],tot=0;
void add(int x,int y){
tot++;
b[tot].e=y;
b[tot].n=s[x];
s[x]=tot;
}
int d[N<<1];
bool cmp(int x,int y){
return sam[x].len<sam[y].len;
}
#define cp complex<long long>
long long sz[N<<1];
cp f[N<<1];//a+bx
long long wei[N<<1],top[N<<1],dfn[N<<1],maxch[N<<1];
vector<int>chain[N<<1];
int find(int x,int l){
while((~sam[top[x]].fail)&&sam[sam[top[x]].fail].len>=l){
//cerr<<"find "<<x<<" -> "<<sam[top[x]].fail<<endl;
x=sam[top[x]].fail;
}
sam[cnt].len=l;
int ans=*lower_bound(chain[top[x]].begin(),chain[top[x]].end(),cnt,cmp);
return ans;
}
int main(){
init();
int n=getint(),m=getint();
scanf("%s",str+1);
for(int i=1;i<=n;i++){
extend(str[i],i);
}
//print();cerr<<endl;
for(int i=0;i<cnt;i++){
d[i]=i;
sz[i]=sam[i].len-sam[sam[i].fail].len;
if(~sam[i].fail){
add(sam[i].fail,i);
//cerr<<"> "<<sam[i].fail<<" "<<i<<endl;
}
maxch[i]=-1;
}
sort(d,d+cnt,cmp);
for(int i=cnt-1;i>=0;--i){
int x=d[i];
sz[sam[x].fail]+=sz[x];
wei[sam[x].fail]+=wei[x];
f[x]=cp(sz[x]+sam[sam[x].fail].len+1,-1);
//cerr<<x<<" "<<sz[x]<<" "<<f[x]<<endl;
}
for(int i=cnt-1;i>=0;--i){
int x=d[i];
for(int j=0;j<S;j++){
if(~sam[x].ch[j]){
int v=sam[x].ch[j];
f[x]+=cp(f[v].real()+f[v].imag(),f[v].imag());
//cerr<<"> "<<v<<"->"<<x<<endl;
}
}
//cerr<<x<<" "<<f[x]<<endl;
}
for(int i=0;i<cnt;i++){
for(int j=s[i];j;j=b[j].n){
if((!~maxch[i])||sz[b[j].e]>sz[maxch[i]])
maxch[i]=b[j].e;
}
}
for(int i=0;i<cnt;i++){
int x=d[i];
if(x==maxch[sam[x].fail])top[x]=top[sam[x].fail];
else top[x]=x;
//cerr<<"top "<<x<<" "<<top[x]<<endl;
chain[top[x]].push_back(x);
}
while(m--){
int l=getint(),r=getint();
int x=find(pos[r],r-l+1);
//cerr<<"find "<<pos[r]<<" "<<x<<endl;
printf("%lld\n",f[x].real()+f[x].imag()*(r-l+1ll));
}
return 0;
}