题意
给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
1<=n,k<=10^5,所有字符串总长不超过10^5,字符串只包含小写字母。
分析
一开始的想法是,把所有串连起来建sam,然后用线段树合并来维护。
看了题解,发现可以建广义sam,然后用set启发式合并即可。虽然时间复杂度比较大,但是代码量比较小,于是果断去打了一发(懒癌晚期)。
一开始按标打广义sam发现过不了,去问了一发栋爷才知道广义sam不是这么建的。然后就去学习了一发。
那么这题我们就可以预处理处sam上每个节点的right集中包含了多少个字符串,和每个节点的第一个包含字符串不小于k的父节点。然后再把每个串在sam上跑一边,统计一下即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<set>
using namespace std;
typedef long long LL;
const int N=100005;
const int M=N*10;
int n,m,last,cnt,num[M],ch[M][26],fa[M],mx[M],b[N],c[M],pts[M],s[N],l[N],r[N],size[M],bel[M];
char str[N];
set<int> a[M];
void ins(int x,int id)
{
if (ch[last][x])
{
int p=last,q=ch[last][x];
if (mx[q]==mx[p]+1) last=ch[last][x],a[last].insert(id);
else
{
int nq=++cnt;mx[nq]=mx[p]+1;num[nq]=nq;a[nq].insert(id);
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];
last=fa[q]=nq;
for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
a[last].insert(id);
}
return;
}
int p,q,np,nq;
p=last;last=np=++cnt;mx[np]=mx[p]+1;a[np].insert(id);num[np]=np;bel[np]=id;
for (;!ch[p][x]&&p;p=fa[p]) ch[p][x]=np;
if (!p) fa[np]=1;
else
{
q=ch[p][x];
if (mx[q]==mx[p]+1) fa[np]=q;
else
{
nq=++cnt;mx[nq]=mx[p]+1;num[nq]=nq;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
}
}
}
void merge(int x,int y)
{
if (a[x].empty()||a[y].empty()) return;
for (set<int>::iterator it=a[y].begin();it!=a[y].end();it++) a[x].insert(*it);
}
void prework()
{
for (int i=1;i<=cnt;i++) b[mx[i]]++;
for (int i=1;i<=r[n];i++) b[i]+=b[i-1];
for (int i=1;i<=cnt;i++) c[b[mx[i]]--]=i;
for (int i=cnt;i>=1;i--)
{
size[c[i]]=a[num[c[i]]].size();
if (!fa[c[i]]) continue;
int a1=num[c[i]],a2=num[fa[c[i]]];
if (a[a1].size()<a[a2].size()) merge(a2,a1);
else num[fa[c[i]]]=a1,merge(a1,a2);
}
for (int i=1;i<=cnt;i++)
if (size[fa[c[i]]]>=m) pts[c[i]]=fa[c[i]];
else pts[c[i]]=pts[fa[c[i]]];
}
LL solve(int l,int r)
{
int now=1,len=0;
LL ans=0;
for (int i=l;i<=r;i++)
{
if (ch[now][s[i]]) now=ch[now][s[i]],len++;
else
{
for (;!ch[now][s[i]]&&now;now=fa[now]);
if (!now) now=1,len=0;
else len=mx[now]+1,now=ch[now][s[i]];
}
if (size[now]>=m) ans+=len;
else ans+=mx[pts[now]];
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
last=cnt=1;num[1]=1;
for (int i=1;i<=n;i++)
{
scanf("%s",str);
int len=strlen(str);
l[i]=r[i-1]+1;r[i]=l[i]+len-1;
last=1;
for (int j=0;j<len;j++)
{
ins(str[j]-'a',i);
s[j+l[i]]=str[j]-'a';
}
}
prework();
for (int i=1;i<=n;i++) printf("%lld ",solve(l[i],r[i]));
return 0;
}