=== ===
这里放传送门
=== ===
题解
首先可以想到如果能求出每一个子串在三个串中分别的出现次数 (a,b,c) ,那么答案就是把每个子串的 a,b,c 都乘起来再加起来。对于求所有子串相关的问题很容易想到后缀自动机,而这个题有多个串那么就是广义后缀自动机。因为广义后缀自动机可以识别这三个串的所有子串,并且相同的子串会记录在相同的节点上。那么只要建立自动机以后每个节点开三个域,对三个串分别统计它Right集合的大小就可以了。注意最后统计答案的时候不能像求最大值一样用f[i]更新f[i-1],因为这样会出现重复。所以每个节点的答案必须严格加到 [Min(s),Max(s)] 这一段区间上去。这里使用了差分的方法。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const long long Mod=1e9+7;
char S[300010];
int len,cnt,Min,v[1000010],b[1000010];
long long f[300010];
struct Node{
Node *ch[27],*fa;
int step,R[5];
Node();
}*null,*Root,t[1000010],*last,*p,*q,*np,*nq;
Node::Node(){
for (int i=0;i<=26;i++) ch[i]=null;
fa=null;step=0;
for (int i=0;i<5;i++) R[i]=0;
}
Node* New(){++cnt;t[cnt]=Node();return t+cnt;}
void insert(int x,int id){
p=last;
if (p->ch[x]!=null){
q=p->ch[x];
if (q->step==p->step+1){
last=q;q->R[id]++;return;
}//如果有这个节点就直接跳转
nq=New();nq->step=p->step+1;nq->R[id]++;
memcpy(nq->ch,q->ch,sizeof(q->ch));
nq->fa=q->fa;q->fa=nq;
while (p->ch[x]==q){p->ch[x]=nq;p=p->fa;}
last=nq;return;
}
last=np=New();
np->step=p->step+1;np->R[id]=1;
while (p->ch[x]==null&&p!=null){
p->ch[x]=np;p=p->fa;
}
if (p==null){np->fa=Root;return;}
q=p->ch[x];
if (q->step==p->step+1){np->fa=q;return;}
nq=New();nq->step=p->step+1;
memcpy(nq->ch,q->ch,sizeof(q->ch));
nq->fa=q->fa;q->fa=np->fa=nq;
while (p->ch[x]==q){p->ch[x]=nq;p=p->fa;}
}
void Getord(){
for (int i=1;i<=cnt;i++) ++b[t[i].step];
for (int i=1;i<=cnt;i++) b[i]+=b[i-1];
for (int i=cnt;i>=1;i--) v[b[t[i].step]--]=i;
for (int i=cnt;i>=1;i--){
Node *now=t+v[i];
for (int j=1;j<=3;j++)
now->fa->R[j]+=now->R[j];
}
}
int main()
{
null=new Node;*null=Node();
Root=last=New();
Min=0x7fffffff;
for (int i=1;i<=3;i++){
gets(S);len=strlen(S);
Min=min(Min,len);
last=Root;
for (int j=0;j<len;j++)
insert(S[j]-'a',i);
}
Getord();
for (int i=1;i<=cnt;i++){
long long x=1;
Node *now=t+i;
int v=now->step,L,R;
if (now->step==0) continue;
for (int j=1;j<=3;j++)
x=(x*(long long)now->R[j])%Mod;
L=now->fa->step+1;R=now->step;
f[L]=(f[L]+x)%Mod;f[R+1]=(f[R+1]-x)%Mod;
}
for (int i=1;i<=cnt;i++) f[i]=(f[i]+f[i-1])%Mod;
for (int i=1;i<=Min;i++)
printf("%I64d%c",(f[i]+Mod)%Mod," \n"[i==Min]);
return 0;
}