题意
对于一个字符串|S|,我们定义fail[i],表示最大的x使得S[1…x]=S[i-x+1…i],满足(x < i)
显然对于一个字符串,如果我们将每个0<=i<=|S|看成一个结点,除了i=0以外i向fail[i]连边,这是一颗树的形状,根是0
我们定义这棵树是G(S),设f(S)是G(S)中除了0号点以外所有点的深度之和,其中0号点的深度为-1
定义key(S)等于S的所有非空子串S’的f(S’)之和
给定一个字符串S,现在你要实现以下几种操作:
1.在S最后面加一个字符
2.询问key(S)
善良的出题人不希望你的答案比long long大,所以你需要将答案对1e9+7取模
第一行一个正整数Q
Q<=10^5
第二行一个长度为Q的字符串S
分析
首先我们来分析一下这道题的各种性质。
f(S)就相当于每个前缀在原串中出现的次数(不包括前缀本身)。
不难发现key(S)就相当于每个非空子串S的f(S)*(n-pos[S]+1)的和,pos[S]表示S结尾字符的位置。
假设我们维护出了当前S每个子串S的f(S)的和,设为tot,当前的答案为ans,当前在结尾新加了一个字符,考虑维护新的答案。
那么新答案必然是ans+tot+sigma(f(每个后缀))。
问题在于如何快速求出sigma(f(每个后缀))。
根据前面推出的性质不难发现该值等于每个后缀在原串中出现次数的和(不包括后缀本身)。
考虑离线建出sam,那么新加入字符的所有后缀串就相当于其在sam上的对应节点到parents树的根路径上的所有串。
我们可以用树链剖分+线段树来维护这棵parents树,每个节点维护该节点对应的子串在元素列中出现的次数,那么就得到了一下算法:
加入一个新字符,设val=其在sam上对应节点到根路径上的权值和,tot+=val,ans+=tot。
把对应节点到根路径上的每个节点加上该节点对应的字符串数量。
当然也可以用lct来维护这个东西。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=100005;
const int MOD=1000000007;
int n,cnt,sz,tim,last[N*2],mx[N*2],fa[N*2],ch[N*2][26],dep[N*2],size[N*2],top[N*2],pos[N*2],num[N],ls,bel[N*2];
char str[N];
struct edge{int to,next;}e[N*2];
struct tree{int s,tag,val;}t[N*10];
void extend(int x)
{
int p,q,np,nq;
p=ls;ls=np=++cnt;mx[np]=mx[p]+1;
for (;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
if (!p) fa[np]=1;
else
{
q=ch[p][x];
if (mx[q]==mx[p]+1) fa[np]=q;
else
{
nq=++cnt;mx[nq]=mx[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
}
}
}
void addedge(int u,int v)
{
e[++sz].to=v;e[sz].next=last[u];last[u]=sz;
}
void dfs1(int x)
{
dep[x]=dep[fa[x]]+1;size[x]=1;
for (int i=last[x];i;i=e[i].next)
{
dfs1(e[i].to);
size[x]+=size[e[i].to];
}
}
void dfs2(int x,int chain)
{
top[x]=chain;pos[x]=++tim;bel[tim]=x;int k=0;
for (int i=last[x];i;i=e[i].next)
if (size[e[i].to]>size[k]) k=e[i].to;
if (!k) return;
dfs2(k,chain);
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=k) dfs2(e[i].to,e[i].to);
}
void pushdown(int d,int l,int r)
{
if (l==r||!t[d].tag) return;
int w=t[d].tag;t[d].tag=0;
t[d*2].s+=(LL)t[d*2].val*w%MOD;t[d*2].s-=t[d*2].s>=MOD?MOD:0;
t[d*2+1].s+=(LL)t[d*2+1].val*w%MOD;t[d*2+1].s-=t[d*2+1].s>=MOD?MOD:0;
t[d*2].tag+=w;t[d*2+1].tag+=w;
}
int ins(int d,int l,int r,int x,int y)
{
if (x>y) return 0;
pushdown(d,l,r);
if (l==x&&r==y)
{
int ans=t[d].s;
t[d].s+=t[d].val;t[d].s-=t[d].s>=MOD?MOD:0;
t[d].tag++;
return ans;
}
int mid=(l+r)/2;
int ans=ins(d*2,l,mid,x,min(y,mid))+ins(d*2+1,mid+1,r,max(x,mid+1),y);
ans-=ans>=MOD?MOD:0;
t[d].s=t[d*2].s+t[d*2+1].s;
t[d].s-=t[d].s>=MOD?MOD:0;
return ans;
}
void build(int d,int l,int r)
{
if (l==r)
{
t[d].val=mx[bel[l]]-mx[fa[bel[l]]];
return;
}
int mid=(l+r)/2;
build(d*2,l,mid);build(d*2+1,mid+1,r);
t[d].val=t[d*2].val+t[d*2+1].val;t[d].val-=t[d].val>=MOD?MOD:0;
}
int solve(int x,int y)
{
int ans=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=ins(1,1,cnt,pos[top[x]],pos[x]);
ans-=ans>=MOD?MOD:0;
x=fa[top[x]];
}
if (dep[x]<dep[y]) swap(x,y);
ans+=ins(1,1,cnt,pos[y]+1,pos[x]);
ans-=ans>=MOD?MOD:0;
return ans;
}
int main()
{
scanf("%d%s",&n,str+1);
ls=cnt=1;
for (int i=1;i<=n;i++) extend(str[i]-'a'),num[i]=ls;
for (int i=2;i<=cnt;i++) addedge(fa[i],i);
dfs1(1);
dfs2(1,1);
build(1,1,cnt);
int ans=0,tot=0;
for (int i=1;i<=n;i++)
{
int val=solve(num[i],1);
tot+=val;tot-=tot>=MOD?MOD:0;
ans+=tot;ans-=ans>=MOD?MOD:0;
printf("%d\n",ans);
}
return 0;
}