http://codeforces.com/problemset/problem/235/C
陈立杰出的后缀自动机,过的人挺少,不过还算是一道中规中矩的后缀自动机吧。
题目大意:给一个字符串S,再给一个字符串T,设T的长度为len,问T的循环串在S中出现的次数,这里循环串的定义是:对于一个长度为len的字符串,我们把它首尾相接,然后从任意位置开始走len步所得到的串我们叫做T的循环串。如abaa的循环串有 abaa,baaa,aaab,aaba。(注意如果重复只算一次。比如aaa的循环串只有一个aaa)
思路:对于字符串S,我们构造S的后缀自动机,然后对于每一个字符串T,我们设T'为T去掉最后一个字符所得到的字符串,然后构造TT',在S的后缀自动机上进行匹配,类似于LCS的做法,我们可以算出对于TT'的每一个位置,可以匹配的最大总长度,那么当匹配长度大于等于len时(这里的len为T的长度),设当前所在状态为p,则我们可以根据par链找到匹配长度为len时所对应的状态,设为q,则我们设状态q所表示的子串出现的次数为q->num,则ans+=q->num,num的计算还是通过拓扑排序,自底向上求即可,注意这里有可能有重复,所以我们还得在每一个状态里设一个标记flag,表示当前状态是否被计算过,若已计算过则跳过即可。代码如下:
#include <iostream>
#include <string.h>
#include <stdio.h>
#define maxn 2000100
#define Smaxn 26
using namespace std;
struct node
{
node *par,*go[Smaxn];
int flag;
int num;
int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn];
void add(int c,int l)
{
node *p=tail,*np=&que[tot++];
np->val=l;
while(p&&p->go[c]==NULL)
p->go[c]=np,p=p->par;
if(p==NULL) np->par=root;
else
{
node *q=p->go[c];
if(p->val+1==q->val) np->par=q;
else
{
node *nq=&que[tot++];
*nq=*q;
nq->val=p->val+1;
np->par=q->par=nq;
while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
}
}
tail=np;
}
int c[maxn],len;
void init()
{
len=1;
tot=0;
memset(que,0,sizeof(que));
root=tail=&que[tot++];
}
char st[2000100];
void solve(int n)
{
int i,j;
memset(c,0,sizeof(c));
for(i=0;i<tot;i++)
c[que[i].val]++;
for(i=1;i<len;i++)
c[i]+=c[i-1];
for(i=0;i<tot;i++)
top[--c[que[i].val]]=&que[i];
for(node *p=root;;p=p->go[str[p->val]-'a'])
{
p->num=1;
if (p->val==len-1)break;
}
for(i=tot-1;i>=0;i--)
{
node *p=top[i];
if(p->par)
{
p->par->num+=p->num;
}
}
int tmp=0;
node *p=root;
for(i=1;i<=n;i++)
{
long long ans=0;
scanf("%s",st);
int l=strlen(st);
memcpy(st+l,st,l);
int ll=2*l;
st[ll-1]='\0';
for(j=0;j<ll-1;j++)
{
int x=st[j]-'a';
if(p->go[x])
{
tmp++;
p=p->go[x];
}
else
{
while(p&&p->go[x]==NULL)
p=p->par;
if(p)
{
tmp=p->val+1;
p=p->go[x];
}
else
{
tmp=0;
p=root;
}
}
if(j>=l-1&&tmp>=l)
{
node *q=p;
while(1)
{
if(l>=q->par->val+1&&l<=q->val)
break;
q=q->par;
}
if(q->flag!=i)
{
ans+=q->num;
q->flag=i;
}
}
}
printf("%I64d\n",ans);
}
}
int main()
{
freopen("dd.txt","r",stdin);
scanf("%s",str);
init();
int i,l=strlen(str);
for(i=0;i<l;i++)
{
add(str[i]-'a',len++);
}
int n;
scanf("%d",&n);
solve(n);
return 0;
}