题目大意
给一个长度为n 的字符串s,每次询问一个串q 问s 有多少个子串与q 循环同构。
s长度,q总长度小于10^6。
分析
就是把q复制一遍,然后再sam上面跑,如果循环了,就不加入答案。
好久没打SAM…需要复习一下。
卡了好久,因为有几句话打反了。
为了程序不出错,最好把root设置成1。
跑的时候注意,用mat记录匹配的字符串长度,注意一定要跑到完全代表s[i~i+len(q)]的点,不然答案会小。
代码
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<set>
#include<bitset>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
typedef long long ll;
typedef double db;
const int N=2e6+5,mo=1e8;//7707253;
int tr[N][26],fail[N],mx[N],rt,kmp[N],lst[N],n,len,i,j,k,ts,q1,q2,d[N],x,tmp,l,mat,vis[N];
char s[N];
ll ans,cnt[N];
int tt,b[N],first[N],next[N];
void cr(int x,int y)
{
tt++;
b[tt]=y;
next[tt]=first[x];
first[x]=tt;
}
int ins(int lst,int x)
{
int np=++ts,p=lst,nq,q;
mx[np]=mx[p]+1;
cnt[np]=1;
while (p&&!tr[p][x]) tr[p][x]=np,p=fail[p];
if (!p) fail[np]=1;else
{
q=tr[p][x];
if (mx[q]==mx[p]+1) fail[np]=q;else
{
nq=++ts;
memcpy(tr[nq],tr[q],sizeof(tr[q]));
mx[nq]=mx[p]+1;
while (p&&tr[p][x]==q) tr[p][x]=nq,p=fail[p];
fail[nq]=fail[q];
fail[q]=fail[np]=nq;
}
}
return np;
}
int main()
{
freopen("data.in","r",stdin);
//freopen("data.out","w",stdout);
scanf("%s\n",s+1);
len=strlen(s+1);
//root=1;
ts=1;
fail[1]=0;
lst[0]=1;
fo(i,1,len) lst[i]=ins(lst[i-1],s[i]-'a');
fo(i,2,ts) cr(fail[i],i);
q1=0;d[q2=1]=1;
while(q1<q2) for(int p=first[d[++q1]];p;p=next[p]) d[++q2]=b[p];
fd(i,ts,1) cnt[fail[d[i]]]+=cnt[d[i]];
scanf("%d\n",&n);
fo(l,1,n)
{
scanf("%s\n",s+1);
len=strlen(s+1);
fo(i,1,len) s[i+len]=s[i];
ans=0;tmp=1;mat=0;
fo(i,1,len*2)
{
while (tmp&&!tr[tmp][s[i]-'a']) tmp=fail[tmp],mat=mx[tmp];
if (!tmp) tmp=1,mat=0;
else tmp=tr[tmp][s[i]-'a'],mat++;
while (tmp&&mx[fail[tmp]]>=len) tmp=fail[tmp],len=mx[tmp];
if (mat>=len&&vis[tmp]!=l) vis[tmp]=l,ans+=cnt[tmp];
}
printf("%I64d\n",ans);
}
}