P6292 区间本质不同子串个数
题解
终于来了,典中典的SAM-LCT组合题!
这题要求的问题简单粗暴,就不要想着找什么性质了,直接上后缀自动机。
考虑用后缀自动机建出 P a r e n t T r e e \rm Parent\,Tree ParentTree,那么原字符串的前缀对应着树上的叶子,其到根的一条链上包含了这个前缀的所有后缀的信息。
朴素的思路是把询问离线下来做扫描线,把每种子串的最靠后的左端点放到线段树上,就可以用区间求和查询答案。
当我们枚举到某个前缀的时候,在 P a r e n t T r e e \rm Parent\,Tree ParentTree 上跳到根,把到根的链上的每个节点的子串的右端点更新,并在线段树上进行修改,这样修改次数是 O ( n ) O(n) O(n) 的。
我们发现每次单点修改太慢了。由于树上到根的一条链上子串的长度都是连续的,并且修改会让很长一段子串的右端点更新为同一个值。如果我们对一段右端点相同的长度连续的子串进行更新的话,显然只需要用一次区间加减的操作。所以我们考虑怎么维护右端点相同的树链。
不知道读者有没有发现,这个更新到根的链的操作很像LCT的access
。如果我们每次操作时在LCT上access
一下,那么树上的每条实链刚好对应着同一个右端点的子串。所以我们不妨用LCT,每次access
的时候把访问到的旧链的子串的贡献去掉,在把新链的贡献加上。
由于LCT跳实链的时间均摊是 O ( log n ) O(\log n) O(logn) 的,所以总复杂度是 O ( n log 2 n ) O(n\log^2n) O(nlog2n),瓶颈在线段树。
代码
#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define uns unsigned
#define IF (it->first)
#define IS (it->second)
#define END putchar('\n')
using namespace std;
const int MAXN=400005;
const ll INF=1e17;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
inline void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
inline ll lowbit(ll x){return x&-x;}
//SAM begin
const int CH=26;
struct SAM{
int ch[CH],fa,len;
}sam[MAXN];
int las=1,tot=1;
inline int samadd(int c){
int p=las,np=las=++tot;sam[np].len=sam[p].len+1;
for(;p&&!sam[p].ch[c];p=sam[p].fa)sam[p].ch[c]=np;
if(!p)sam[np].fa=1;
else{int q=sam[p].ch[c],nq;
if(sam[q].len==sam[p].len+1)sam[np].fa=q;
else{
nq=++tot,sam[nq]=sam[q],sam[nq].len=sam[p].len+1;
sam[q].fa=sam[np].fa=nq;
for(;p&&sam[p].ch[c]==q;p=sam[p].fa)sam[p].ch[c]=nq;
}
}return np;
}
//SAM end
//LCT begin
struct spl{
int h[2],fa,a,tg;bool lz;spl(){}
spl(int A){h[0]=h[1]=fa=tg=0,lz=0,a=A;}
}t[MAXN];
inline void cover(int x,int c=0){
if(!x)return;
if(c)t[x].a=t[x].tg=c;
else t[x].lz^=1,swap(t[x].h[0],t[x].h[1]);
}
inline void pushd(int x){
if(t[x].lz)cover(t[x].h[0]),cover(t[x].h[1]),t[x].lz=0;
if(t[x].tg)cover(t[x].h[0],t[x].tg),cover(t[x].h[1],t[x].tg),t[x].tg=0;
}
inline void update(int x){
//do nothing
}
inline bool sd(int x){return x==t[t[x].fa].h[1];}
inline bool isroot(int x){return x!=t[t[x].fa].h[sd(x)];}
inline void lin(int x,int y,bool f){if(x)t[x].h[f]=y;if(y)t[y].fa=x;}
inline void rott(int x){
if(!t[x].fa||isroot(x))return;
bool d1=sd(x),d2=sd(t[x].fa);
int fa=t[x].fa,ff=t[fa].fa,sn=t[x].h[d1^1];
if(isroot(fa))t[x].fa=ff;
else lin(ff,x,d2);
lin(x,fa,d1^1),lin(fa,sn,d1);
update(fa),update(x),update(ff);
}
inline void pushtag(int x){
if(!isroot(x))pushtag(t[x].fa);
pushd(x);
}
inline void splay(int x){
pushtag(x);
while(!isroot(x)){
if(!isroot(t[x].fa)){
if(sd(x)==sd(t[x].fa))rott(t[x].fa);
else rott(x);
}rott(x);
}
}
inline void access(int x){
for(int y=0;x;y=x,x=t[x].fa)splay(x),t[x].h[1]=y,update(x);
}
inline void makeroot(int x){
access(x),splay(x),cover(x);
}
inline void LINK(int x,int y){
makeroot(x),access(y),splay(y);
if((x^y)&&t[x].fa==0)t[x].fa=y;
}
inline void CUT(int x,int y){
makeroot(x),access(x),splay(y);
if(!t[y].h[0]&&t[y].fa==x)t[y].fa=0;
}
//LCT end!!
//zkw begin
int p;
ll f[MAXN<<2],lz[MAXN<<2];
inline void init(int n){
for(p=1;p<n+2;p<<=1);
}
inline void add(int l,int r,ll d){
if(l>r||l<1)return;
int siz=1;
for(l=p+l-1,r=p+r+1;l^1^r;){
if(~l&1)f[l^1]+=siz*d,lz[l^1]+=d;
if(r&1)f[r^1]+=siz*d,lz[r^1]+=d;
l>>=1,r>>=1,siz<<=1;
f[l]=f[l<<1]+f[l<<1|1]+lz[l]*siz;
f[r]=f[r<<1]+f[r<<1|1]+lz[r]*siz;
}
for(l>>=1,siz<<=1;l;l>>=1,siz<<=1)
f[l]=f[l<<1]+f[l<<1|1]+lz[l]*siz;
}
inline ll query(int l,int r){
if(l>r)return 0;
int sl=0,sr=0,siz=1;ll res=0;
for(l=p+l-1,r=p+r+1;l^1^r;){
if(~l&1)res+=f[l^1],sl+=siz;
if(r&1)res+=f[r^1],sr+=siz;
l>>=1,r>>=1,siz<<=1,res+=sl*lz[l]+sr*lz[r];
}sl+=sr;
for(l>>=1;l;l>>=1)res+=sl*lz[l];
return res;
}
//zkw end
inline void acc(int x){
for(int y=0;x;y=x,x=t[x].fa){
splay(x);
int l=t[x].a-sam[x].len+1,r=t[x].a-sam[t[x].fa].len;
add(l,r,-1),t[x].h[1]=y,update(x);
}
}
char s[MAXN];
int n,m,L[MAXN],R[MAXN],id[MAXN];
ll as[MAXN];
vector<int>ad[MAXN];
signed main()
{
scanf("%s",s+1),n=strlen(s+1);
for(int i=1;i<=n;i++)id[i]=samadd(s[i]-'a');
for(int i=2;i<=tot;i++)LINK(i,sam[i].fa);
m=read();
makeroot(1);
for(int i=1;i<=m;i++){
L[i]=read(),R[i]=read();
ad[R[i]].push_back(i);
}
init(n);
for(int i=1;i<=n;i++){
int x=id[i];
acc(x),splay(x),cover(x,i);
add(1,i,1);
for(int y:ad[i])as[y]=query(L[y],R[y]);
}
for(int i=1;i<=m;i++)print(as[i]);
return 0;
}