题目
你有一个字符串。
你需要支持两种操作。
1:在字符串的末尾插入一个字符c
2:询问当前字符串的[l,r]子串中的不同子串个数
强制在线
n,m<=50000
思路
暴力:
直接用SAM是O(N3)的
正解:
不难发现这题的操作 1 就是 SAM 的 extend 过程。
设进行一次 1 操作后字符串的长度为 len,则当次 extend 会使字符串增加以第 len 位为结尾的所有子串。
这些子串对应的就是 SAM 的 parents 树上一个叶子节点到根的链,且那个叶子节点就是你新建的表示子串 S1⋯len 的节点。
于是我们考虑用 LCT 动态维护 SAM 的 parents 树。
(好恶心啊,现场只有神仙才能做出来吧)
把last相同的放在同一颗splay里
我们考虑如何维护某个区间 [l,r] 中的不同子串数量。
这个问题可以简化为给你 n 个数,每次询问某个区间 [l,r] 中有多少个不同的数。
于是,容斥!!!用总数量减去那些非最后一次出现的数。
考虑预处理答案,从前往后依次加入每个数,加入第 i 个数即将主席树的第 i 个版本的第 i 位加 1。若该数在之前出现过,设上一次出现的位置为 lst,则在主席树的第 i 个版本的第 j 位减 1。
所以区间内不同子串数量也可以类似地用主席树解决。
于是,码码码!!!!!
好长,200多行
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+77;
int n,m,maxlen,rt[N][2];
struct SegmentTree
{
int ls[N*100],rs[N*100],sz,siz[N*100];
ll sumv[N*100];
int cpyNode(int pre)
{
int x=++sz;
ls[x]=ls[pre];
rs[x]=rs[pre];
sumv[x]=sumv[pre];
siz[x]=siz[pre];
return x;
}
void update(int & o,int pre,int l,int r,int ql,int qr,int v)
{
o=cpyNode(pre);
if(ql==l && qr==r) { sumv[o]+=v,siz[o]++; return; }
int mid=(l+r)>>1;
if(qr<=mid)
update(ls[o],ls[pre],l,mid,ql,qr,v);
else if(ql>mid)
update(rs[o],rs[pre],mid+1,r,ql,qr,v);
else
update(ls[o],ls[pre],l,mid,ql,mid,v),
update(rs[o],rs[pre],mid+1,r,mid+1,qr,v);
}
int querySize(int o,int l,int r,int p)
{
if(l==r) return siz[o];
int mid=(l+r)>>1;
int Ans=siz[o];
if(p<=mid) Ans+=querySize(ls[o],l,mid,p);
else Ans+=querySize(rs[o],mid+1,r,p);
return Ans;
}
ll querySum(int o,int l,int r,int p)
{
if(l==r) return sumv[o];
int mid=(l+r)>>1;
ll Ans=sumv[o];
if(p<=mid) Ans+=querySum(ls[o],l,mid,p);
else Ans+=querySum(rs[o],mid+1,r,p);
return Ans;
}
}rap;
namespace LCT
{
static const int SIZE=2e5+77;
int ch[SIZE][2],fa[SIZE];
int pos[SIZE],len[SIZE];
bool isrt(int x) { return !fa[x] || (ch[fa[x]][0] != x && ch[fa[x]][1] != x); }
bool c(int x) { return ch[fa[x]][1]==x; }
void rotate(int x)
{
int f=fa[x],p=fa[f],d=c(x);
if(!isrt(f)) ch[p][c(f)]=x;
fa[x]=p;
ch[f][d]=ch[x][d^1]; fa[ch[f][d]]=f;
ch[x][d^1]=f; fa[f]=x;
}
int findrt(int x) { return !isrt(x) ? findrt(fa[x]) : x; }
void splay(int x)
{
swap(pos[findrt(x)],pos[x]);
while(!isrt(x))
{
int f=fa[x];
if(!isrt(f))
{
if(c(f)==c(x)) rotate(f);
else rotate(x);
}
rotate(x);
}
}
void access(int x,int i)
{
rt[i][0]=rt[i-1][0];
rt[i][1]=rt[i-1][1];
for(int v=0; x; v=x,x=fa[x])
{
splay(x);
if(len[x] && pos[x])
{
int pl=pos[x]-len[x]+1,pr=pos[x]-len[fa[x]];
if(pl>1) rap.update(rt[i][0],rt[i][0],1,maxlen,1,pl-1,pr-pl+1);
rap.update(rt[i][1],rt[i][1],1,maxlen,pl,pr,pr);
}
if(ch[x][1]) pos[ch[x][1]]=pos[x];
ch[x][1]=v;
pos[v]=0;
pos[x]=i;
}
}
void cut(int x)
{
splay(x);
pos[ch[x][0]]=pos[x];
fa[ch[x][0]]=fa[x];
ch[x][0]=0;
}
void link(int x,int y)
{
splay(x);
fa[x]=y;
}
}
struct SAM
{
static const int SIZE=2e5+77;
int nxt[SIZE][26],fa[SIZE],sz,pos[SIZE],len[SIZE],rt,last;
SAM() { init(); }
int newnode(int l,int id)
{
memset(nxt[sz],0,sizeof(nxt[sz]));
pos[sz]=id;
fa[sz]=0;
len[sz]=l;
return sz++;
}
void init()
{
sz=1;
rt=newnode(0,0);
last=rt;
}
void add(char x,int i)
{
int c=x-'a';
int now=newnode(len[last]+1,i);
LCT::len[now]=len[now];
int p=last;
while(p && !nxt[p][c])
{
nxt[p][c]=now;
p=fa[p];
}
if(!p) LCT::link(now,fa[now]=rt);
else
{
int q=nxt[p][c];
if(len[p]+1==len[q])
LCT::link(now,fa[now]=q);
else
{
int u=newnode(len[p]+1,pos[q]);
LCT::len[u]=len[u];
for(int i=0; i<26; i++) nxt[u][i]=nxt[q][i];
LCT::cut(q);
LCT::pos[u]=LCT::pos[q];
LCT::link(u,fa[u]=fa[q]);
LCT::link(now,fa[now]=u);
LCT::link(q,fa[q]=u);
while(nxt[p][c]==q)
{
nxt[p][c]=u;
p=fa[p];
}
}
}
LCT::access(now,i);
last=now;
}
}cxk;
char str[N];
int type;
int main()
{
scanf("%d",&type);
scanf("%s %d",str+1,&m);
n=strlen(str+1);
maxlen=n+m;
for(int i=1; i<=n; i++)
cxk.add(str[i],i);
ll lastans=0;
for(int i=1; i<=m; i++)
{
int opt;
scanf("%d",&opt);
if(opt==1)
{
char ch[5];
scanf("%s",ch);
str[++n]=(ch[0]-'a'+type*lastans)%26+'a';
cxk.add(str[n],n);
}
else
{
int l,r;
scanf("%d%d",&l,&r);
l=(l-1+lastans*type)%n+1;
r=(r-1+lastans*type)%n+1;
ll val1=rap.querySum(rt[r][0],1,maxlen,l);
ll val2=rap.querySum(rt[r][1],1,maxlen,l);
ll val3=rap.querySize(rt[r][1],1,maxlen,l);
lastans=1ll*(r-l+1)*(r-l+2)/2ll;
lastans-=val1;
lastans-=val2-1ll*(l-1)*val3;
printf("%lld\n",lastans);
}
}
}