题目大意
给出一个由‘0’~‘9’构成的长度为n的字符串s,f[i]表示s从i断开后(即s[1~n]分成s[1~i]和s[i+1~n]两个串),两个串中本质不同的串的个数(例如一个不同的子串a在两个串出现的总次数为cnt,cnt>1则对f[i]+1)。
输出
∑n−1j=1f[j]∗100013n−j−1mod(109+7)
n<=2*10^5
后缀自动机
求本质不同的串的个数,很容易想到用SAM来做。
先求出原串中本质不同的串的个数,记作ans
正难则反,考虑重新定义f[i],表示有多少个本质不同的串s,其所有出现位置都跨过了i(设其中一个位置的左右端点为l,r,则跨过i当且仅当
l<=i<r
)。
答案就等于
∑n−1j=1(ans−f[j])∗100013n−j−1mod(109+7)
我们考虑一个子串的right集(表示其所有结束位置),设maxr表示该子串的最大结束位置,minr表示最小结束位置,len表示该串的长度。
当maxr-len>=minr时,该子串对于f数组不会产生贡献。
当maxr-len
<
<script id="MathJax-Element-758" type="math/tex"><</script>minr时,该子串对f[maxr-len+1~minr-1]有1的贡献。
然而枚举所有本质不同的子串是n^2级的,显然过不了。
考虑SAM中的一个节点,该节点表示的所有子串的right集是相同的,且这些串的len是一段连续的区间,所以对f的贡献是加一段等差数列再使一段增加同一个值。
一个节点maxr和minr怎么求呢?我们知道该节点的right集=该节点所有儿子节点的right集的并集,所以我们可以从叶子节点往上更新,这个可以以节点的maxlen为关键字排序,再扫一遍更新,而一个节点的minlen显然等于该节点的父亲的maxlen+1。
处理出这些,我们就可以计算f了。
“使一段增加同一个值”我们很容易想到差分,在开头位置+x,在结尾+1的位置-x,再求一遍前缀和即可,然而加一段等差数列要怎么做呢?
如果用数据结构去维护,则复杂度要带个log。
我们可以继续考虑差分,设g[i]=f[i]-f[i-1],加一段等差数列,则对于g[i]相当于使一段增加同一个值,于是可以根据上述做法,求出g再求出f。
代码
#include<cstring>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fod(i,a,b) for(i=a;i>=b;i--)
#define ll long long
using namespace std;
const int maxn=200000+5;const int mo=1000000007;
struct sam{
int len,son[10],pre,mx,lem,mn;
} s[maxn*2];
int i,j,n,num,last,id[maxn*2];
ll ans,an,w[maxn],f[maxn],g[maxn];
char c[maxn];
void add(int x){
int np=++num,p=last;
s[np].len=s[p].len+1;s[np].mx=s[np].mn=i;
while (p!=-1&&s[p].son[x]==0) s[p].son[x]=np,p=s[p].pre;
if (p==-1) s[np].pre=0;else{
int q=s[p].son[x];
if (s[q].len==s[p].len+1) s[np].pre=q;else{
int nq=++num;
s[nq]=s[q];
s[nq].len=s[p].len+1;
s[q].pre=s[np].pre=nq;
while (p!=-1&&s[p].son[x]==q) s[p].son[x]=nq,p=s[p].pre;
}
}
last=np;
}
int main(){
int t;scanf("%d",&t);
while (t,t--){
scanf("%d",&n);
scanf("%s",c+1);
memset(s,0,sizeof(s));
s[0].pre=-1;num=last=0;an=0;
fo(i,1,n) {
add(c[i]-'0');
int len=s[last].len-s[s[last].pre].len;an+=len;
}
memset(w,0,sizeof(w));
fo(i,1,num) w[s[i].len]++;
fo(i,1,n) w[i]+=w[i-1];
fod(i,num,1) id[w[s[i].len]--]=i;
fod(i,num,1) {
int x=id[i],fa=s[x].pre;
s[fa].mx=max(s[fa].mx,s[x].mx);
s[fa].mn=min(s[fa].mn,s[x].mn);
s[x].lem=s[fa].len+1;
}
fo(i,1,n) g[i]=f[i]=0;
fo(i,1,num) {
int mx=s[i].mx,mn=s[i].mn,len=s[i].len,lem=s[i].lem;
if (mx-len<mn){
lem=max(lem,mx-mn+1);
g[mx-len+1]++,g[mx-lem+2]--;
f[mn]-=len-lem+1;
}
}
fo(i,1,n) g[i]+=g[i-1];
fo(i,1,n) f[i]+=f[i-1]+g[i];
w[0]=1;
ans=0;
fo(i,1,n) w[i]=w[i-1]*100013%mo;
fo(i,1,n-1) ans=(ans+(an-f[i])%mo*w[n-i-1]%mo)%mo;
printf("%lld\n",ans);
}
}