题意:
给你一个长度为
n
n
n的字符串,有
m
m
m次询问,每次询问给出一对
(
l
,
r
)
(l,r)
(l,r),问你有多少种把序列划分成三段的方法,使得
S
1...
i
,
S
i
+
1...
j
−
1
,
S
j
.
.
.
n
S_{1...i},S_{i+1...j-1},S_{j...n}
S1...i,Si+1...j−1,Sj...n三段中至少有一点包含子串
S
l
.
.
.
r
S_{l...r}
Sl...r,对于每一次询问回答符合要求的
(
i
,
j
)
(i,j)
(i,j)的对数。
n
<
=
1
e
5
,
m
<
=
3
e
5
n<=1e5,m<=3e5
n<=1e5,m<=3e5,字符集是
0
−
9
0-9
0−9的数字。
题解:
去年一轮省选当场爆零了的题,将近一年过后,感觉自己水平提高并不是太大啊,现在做还是毫无思路呐。可能唯一的进步是照着题解和代码看能差不多看懂了吧,这样今年一定还是要再被吊锤啊。我说一下题解的做法吧。
首先我们考虑对于一次询问,怎么快速找出都有哪些位置出现了当前这个询问串。这个可以用SAM来做,做法是我们建出SAM之后建出它的parent树,我们在加入这个串的时候,记录一下加进去的每一个字符分别对应着哪个节点。这样我们找到询问串的右端点,然后在parent树上倍增,找到一个长度大于等于当前询问串的长度最小的节点,那么在parent树中以这个节点为根的子树内,所有结尾位置再向前延伸到当前询问串的长度之后形成的串都会是一个询问串。我们要维护这个集合,也就是parent树上的每一个节点对应的出现的串的位置,由于串长在询问的时候就已知了,所以我们只维护一个右端点,就可以知道每个串的信息了。我们维护的方式是用一个线段树合并,但是我们在合并的时候会把原来子树的信息给破坏掉,所以我们就把离线下来,对着parent树dfs一遍,一边dfs一边合并,合并完了再回答倍增之后到了当前点的所有询问的答案。
下面再说说怎么求这个答案,要分情况讨论。我们直接求答案是比较难求的,那么正难则反,我们考虑用总的方案数减去分成三段之后都不包含一个完整的 S l . . . r S_{l...r} Sl...r的方案数。总方案数就是 C n − 1 2 C_{n-1}^2 Cn−12,那么我们考虑求分成三段之后不包含一个完整的询问串的方案数。
我们假设询问串在原串中出现了若干次,我们要求分成三段,每一段都不包含一个完整的询问串,那么可以看作把原串从某两个字符之间的位置切两刀,分成三段,这两刀要把所有出现了的串都切成两段。显然如果有大于等于三段互不相交的询问串出现,那么一定没法全部切断。我们把所有出现的询问串按照左端点递增的顺序考虑,给它们的左右端点分别标号为 ( l 1 , r 1 ) , ( l 2 , r 2 ) . . . ( l n , r n ) (l_1,r_1),(l_2,r_2)...(l_n,r_n) (l1,r1),(l2,r2)...(ln,rn),由于串长是相同的,所以不难发现右端点位置也是单调递增的。为了不算重,我们规定第一刀切的位置在第二刀切的位置的左边。而枚举边界 r 1 r_1 r1和 r n r_n rn可以通过在我们之前维护的线段树上二分得到。
先说两种比较特殊的情况。
第一种是我们第一刀一个串都没切到,但是第二刀一次就能切断所有出现的询问串。那么答案就是第一个串的左端点之前的长度乘所有串的交集的长度。
第二种是我们第一刀切断了所有串,第二刀只需要在第一刀后面的位置任意选取就可以了。第一刀可以选的位置是所有串的交集,那么我们切完第一刀之后,随着第一刀位置逐渐向右,第二刀的可行位置是每次不断减少一个的,那么答案是一个公差是 1 1 1(也可是说公差是 − 1 -1 −1)等差数列的和。
然后是一些比较普遍的情况。
第一种情况是,我们第一刀切开了一些串,第二刀切开了其他的所有串。我们可以把上面的这些 l i l_i li和 r i r_i ri都看作一些把原串分成若干段的断点。我们考虑枚举第一刀切开的位置,我们要保证最左边的第一个串被第一刀切到了,不难发现,能被一刀切开的串是连续的一段,于是我们考虑枚举 i i i,设 1 − i 1-i 1−i的第一刀切开的串, i + 1 − n i+1-n i+1−n是第二刀切开的串。那么有 ∑ i = 1 n \sum_{i=1}^n ∑i=1n第 1 1 1次出现与第 i i i次出现的交 ∗ * ∗第 i + 1 i+1 i+1次出现与第 n n n次出现的交。这里的交是指相交部分不被下一个串覆盖情况下的交,这样才能保证不重不漏。
于是我们会发现这种意义下第 1 1 1个串和第 i i i个串的交通常是 l i + 1 − l i l_{i+1}-l_i li+1−li,等价于 r i + 1 − r i r_{i+1}-r_i ri+1−ri。有一个特例,是当前的可行右端点不再是某个串的左端点的前一个位置了,而是第一个串的右端点,此时要取个min。还有一种比较特别的是,当 i = n i=n i=n的时候,我们可能会对应第二种特殊情况,所以单独处理。
然后再考虑第 i + 1 i+1 i+1个串和第 n n n个串的交。通常情况下,这个交是 r i + 1 − l n + 1 r_{i+1}-l_{n}+1 ri+1−ln+1,但是当 i = 1 i=1 i=1的时候,可能对应第一种特殊情况,所以当 i = 1 i=1 i=1的时候我们也特殊处理。
这样,对于除了 i = 1 i=1 i=1和 i = n i=n i=n的部分,我们整理一下式子: ∑ i = 2 n − 1 ( r i + 1 − r i ) ∗ ( r i + 1 − l n − 1 + 1 ) \sum_{i=2}^{n-1}(r_{i+1}-r_i)*(r_{i+1}-l_{n-1}+1) ∑i=2n−1(ri+1−ri)∗(ri+1−ln−1+1),把括号拆开,得 ∑ i = 2 n − 1 ( r i + 1 − r i ) ∗ r i + 1 − ( r i + 1 − r i ) ∗ ( l n − 1 − 1 ) \sum_{i=2}^{n-1}(r_{i+1}-r_i)*r_{i+1}-(r_{i+1}-r_i)*(l_{n-1}-1) ∑i=2n−1(ri+1−ri)∗ri+1−(ri+1−ri)∗(ln−1−1),我们发现对于后面一半,再求和时很多相邻的两项都消掉了,只剩下含 r 2 r_2 r2和 r n − 1 r_{n-1} rn−1的项没被消掉,于是得 ∑ i = 2 n − 1 ( r i + 1 − r i ) ∗ r i + 1 − ( l n − 1 − 1 ) ∗ ( r n − 1 − r 2 ) \sum_{i=2}^{n-1}(r_{i+1}-r_i)*r_{i+1}-(l_{n-1}-1)*(r_{n-1}-r_2) ∑i=2n−1(ri+1−ri)∗ri+1−(ln−1−1)∗(rn−1−r2)。我们考虑在之前的那个线段树上同时维护这个东西,我们分成前后两部分来维护, l n − 1 l_{n-1} ln−1是在线段树上二分出来的,所以后面的只需要维护一个最大值和一个最小值;前面一部分是关于相邻两个的信息,这个在线段树的区间信息合并的时候,由于端点是单调的,所以前半个区间的最大值作为 r i r_i ri,后半个区间的最小值作为 r i + 1 r_{i+1} ri+1来合并就可以了。最大值和最小值同理用这个端点单调的性质来更新。不明白的话可以看看代码或者自己画个图。
这样我们就分析完所有的情况以及步骤了,复杂度 O ( n l o g n + m l o g n ) O(nlogn+mlogn) O(nlogn+mlogn)。
代码细节比较多,写起来也比较长,我的代码有5个k,最重要的也是比较难懂的求答案的部分我写了不少注释,不保证注释都写对了,但是我感觉是我写的对的,可能能帮助理解。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,rt=1,lst=1,len[400010],fa[400010],ch[400010][10],cnt=1,cnt1;
int num,book[400010],root[400010],hed[400010],f[400010][21];
long long ans[300010];
char s[100010];
struct node
{
int l,r;
long long mx,mn,sum;
}tr[4000010];
struct edge
{
int to,next;
}a[800010];
struct qwq
{
int l,r,id;
};
vector<qwq> q[300010];
inline int read()
{
int x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline void pushup(int rt)
{
tr[rt].mx=tr[tr[rt].r].mx;
if(!tr[rt].mx)
tr[rt].mx=tr[tr[rt].l].mx;
tr[rt].mn=tr[tr[rt].l].mn;
if(!tr[rt].mn)
tr[rt].mn=tr[tr[rt].r].mn;
tr[rt].sum=tr[tr[rt].l].sum+tr[tr[rt].r].sum;
if(tr[tr[rt].l].mx&&tr[tr[rt].r].mn)
tr[rt].sum+=tr[tr[rt].r].mn*(tr[tr[rt].r].mn-tr[tr[rt].l].mx);
}
inline void update(int &rt,int l,int r,int x)
{
if(!rt)
rt=++num;
if(l==r)
{
tr[rt].mx=x;
tr[rt].mn=x;
tr[rt].sum=0;
return;
}
int mid=(l+r)>>1;
if(x<=mid)
update(tr[rt].l,l,mid,x);
else
update(tr[rt].r,mid+1,r,x);
pushup(rt);
}
inline void insert(int x,int y)
{
int cur=++cnt,pre=lst;
lst=cur;
len[cur]=len[pre]+1;
for(;pre&&!ch[pre][x];pre=fa[pre])
ch[pre][x]=cur;
if(!pre)
fa[cur]=rt;
else
{
int ji=ch[pre][x];
if(len[ji]==len[pre]+1)
fa[cur]=ji;
else
{
int gg=++cnt;
len[gg]=len[pre]+1;
memcpy(ch[gg],ch[ji],sizeof(ch[ji]));
fa[gg]=fa[ji];
fa[ji]=fa[cur]=gg;
for(;pre&&ch[pre][x]==ji;pre=fa[pre])
ch[pre][x]=gg;
}
}
book[y]=cur;
update(root[cur],1,n,y);
}
inline void add(int from,int to)
{
a[++cnt1].to=to;
a[cnt1].next=hed[from];
hed[from]=cnt1;
}
inline void merge(int &x,int y)
{
if(x==0||y==0)
{
x=x+y;
return;
}
merge(tr[x].l,tr[y].l);
merge(tr[x].r,tr[y].r);
pushup(x);
}
inline int lower(int rt,int l,int r,int x)
{
if(tr[rt].mn>=x)
return tr[rt].mn;
int mid=(l+r)>>1;
if(tr[tr[rt].l].mx&&tr[tr[rt].l].mx>=x)
return lower(tr[rt].l,l,mid,x);
else
return lower(tr[rt].r,mid+1,r,x);
}
inline int upper(int rt,int l,int r,int x)
{
if(tr[rt].mx<=x)
return tr[rt].mx;
int mid=(l+r)>>1;
if(tr[tr[rt].r].mn&&tr[tr[rt].r].mn<=x)
return upper(tr[rt].r,mid+1,r,x);
else
return upper(tr[rt].l,l,mid,x);
}
inline node query(int rt,int l,int r,int le,int ri)
{
node x;
if(le>ri)
return x;
if(le<=l&&r<=ri)
{
x.mx=tr[rt].mx;
x.mn=tr[rt].mn;
x.sum=tr[rt].sum;
return x;
}
int mid=(l+r)>>1;
if(mid>=ri)
return query(tr[rt].l,l,mid,le,ri);
else if(mid+1<=le)
return query(tr[rt].r,mid+1,r,le,ri);
else
{
node y=query(tr[rt].l,l,mid,le,ri),z=query(tr[rt].r,mid+1,r,le,ri);
x.mx=z.mx;
if(!x.mx)
x.mx=y.mx;
x.mn=y.mn;
if(!x.mn)
x.mn=z.mn;
x.sum=y.sum+z.sum;
if(y.mx&&z.mn)
x.sum+=z.mn*(z.mn-y.mx);
}
return x;
}
inline long long calc(int rt,int l,int r)
{
if(l==r)
return 0;
int len=r-l;
int r1=tr[rt].mn,l1=r1-len+1,rn=tr[rt].mx,ln=rn-len+1;
int ql,qr;
ql=lower(rt,1,n,ln);//找大于等于ln的第一个数,不是r1就不能一刀切开所有
if(ql!=r1)
ql=upper(rt,1,n,ql-1);//找小于等于,也就是可以当第一刀的第一个右端点的位置
else
ql=0;
qr=upper(rt,1,n,r1+len-1);//如果qr=rn的话就意味着第一刀可以一次切开所有的
long long res=0;
if(ql>qr)
return 0;
if(ql==qr)
{
if(ql==0)
res+=1ll*(l1-2)*(r1-ln+1);//靠前的不切断任何字符串,靠后的切断了所有字符串
else if(qr==rn)//靠前的切断了所有字符串,只要保证第二次在第一次后面切就可以
{
long long x=n-r1,y=n-ln;
if(x>y)
swap(x,y);
res+=(x+y)*(y-x+1)/2;
}//随着切的第一刀往右移,可以作为第二刀的位置依次减1,所以是个等差数列
else
{
int ji=lower(rt,1,n,ql+1);
res+=(min(r1,ji-len)-(ql-len+1)+1)*(ji-ln+1);
//注意左端点要和r1取min来保证第一段被切到了,前半部分就是所有第一刀可行的区间,乘号后面就是后半部分可行的区间
}
}
else//特判ql和qr,其他的一起算
{
node x=query(rt,1,n,ql+1,qr);//+1之后就是不包含ql和qr的情况了
res+=x.sum-1ll*(ln-1)*(x.mx-x.mn);
if(ql==0)
res+=1ll*(l1-2)*(r1-ln+1);
else
{
int ji=lower(rt,1,n,ql+1);
res+=(min(r1,ji-len)-(ql-len+1)+1)*(ji-ln+1);
}
if(qr==rn)
{
long long x=n-r1,y=n-ln;
if(x>y)
swap(x,y);
res+=(x+y)*(y-x+1)/2;
}
else
{
int ji=lower(rt,1,n,qr+1);
res+=(min(r1,ji-len)-(qr-len+1)+1)*(ji-ln+1);
}
}
return res;
}
inline void dfs(int x)
{
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
dfs(y);
merge(root[x],root[y]);
}
int ji=q[x].size();
for(int i=0;i<ji;++i)
ans[q[x][i].id]=1ll*(n-1)*(n-2)/2-calc(root[x],q[x][i].l,q[x][i].r);
}
int main()
{
n=read();
m=read();
scanf("%s",s+1);
for(int i=1;i<=n;++i)
insert(s[i]-'0',i);
for(int i=2;i<=cnt;++i)
{
add(fa[i],i);
f[i][0]=fa[i];
}
for(int j=1;j<=20;++j)
{
for(int i=1;i<=cnt;++i)
f[i][j]=f[f[i][j-1]][j-1];
}
for(int w=1;w<=m;++w)
{
int l=read(),r=read();
int cur=book[r];
for(int i=20;i>=0;--i)
{
if(len[f[cur][i]]>=r-l+1)
cur=f[cur][i];
}
qwq x;
x.l=l;
x.r=r;
x.id=w;
q[cur].push_back(x);
}
dfs(1);
for(int i=1;i<=m;++i)
printf("%lld\n",ans[i]);
return 0;
}