题目大意:求若干个后缀两两间的LCP之和。后缀数总和不大于3*10^6;
后缀自动机加虚树。
看这种询问有多,总点不太多的又在一棵树上可以考虑虚树。
把原串反过来,等于求前缀的最长公共后缀。
首先有个显然的性质是两个前缀的LCS是它们在parent树上的lca的max。如果不了解的可以先做这题:点击打开链接
建虚树后,跟bzoj3238一样,在树上统计答案就行了。
ans要除二,因为两两只算一次。
好像%不%都无所谓,答案没那么大。
code:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
struct SAM{
int par,max,a[30];
}sam[1000010];int root=1,tot=1,tail=1,pre[500010];
int ys[1000010],z=0,son[1000010],h[500010];
int check[1000010],tim=0,num=0;
LL ans;
struct trnode{
int dep,fa[25];
}tr[1000010];
struct node{
int y,next;
}a[1000010];int last[1000010],len=0,n,m;
char s[500010];
void ins(int x,int y)
{
a[++len].y=y;
a[len].next=last[x];last[x]=len;
}
void addsam(int c,int len)
{
int p=tail,np=++tot;
sam[np].max=len;
for(;p&&!sam[p].a[c];p=sam[p].par) sam[p].a[c]=np;
tail=np;
if(!p) sam[np].par=root;
else
{
int q=sam[p].a[c];
if(sam[q].max==sam[p].max+1) sam[np].par=q;
else
{
int nq=++tot;sam[nq]=sam[q];
sam[nq].max=sam[p].max+1;
sam[q].par=sam[np].par=nq;
for(;p&&sam[p].a[c]==q;p=sam[p].par) sam[p].a[c]=nq;
}
}
}
int findlca(int x,int y)
{
if(tr[x].dep<tr[y].dep) swap(x,y);
for(int i=19;i>=0;i--)
if((1<<i)<=tr[x].dep-tr[y].dep) x=tr[x].fa[i];
if(x==y) return x;
for(int i=19;i>=0;i--)
if((1<<i)<=tr[x].dep&&tr[x].fa[i]!=tr[y].fa[i])
x=tr[x].fa[i],y=tr[y].fa[i];
return tr[x].fa[0];
}
void dfs(int x,int fa)
{
ys[x]=++z;
tr[x].dep=tr[fa].dep+1;tr[x].fa[0]=fa;
for(int i=1;(1<<i)<=tr[x].dep;i++)
tr[x].fa[i]=tr[tr[x].fa[i-1]].fa[i-1];
for(int i=last[x];i;i=a[i].next)
dfs(a[i].y,x);
}
bool cmp(int a,int b){return ys[a]<ys[b];}
int sta[1000010],top=0;
void dfs(int x)
{
if(check[x]==tim) son[x]=1;
else son[x]=0;
int totson=0;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
dfs(y);totson+=son[y];
}
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
ans+=(LL)son[y]*(totson-son[y])*sam[x].max;
}
ans+=(LL)totson*son[x]*2*sam[x].max;
last[x]=0;
son[x]+=totson;
}
void build()
{
top=0;sta[++top]=1;sort(h+1,h+num+1,cmp);
len=0;
for(int i=1;i<=num;i++)
{
int x=h[i],lca=findlca(x,sta[top]);
if(lca==sta[top]) sta[++top]=x;
else
{
while(top>=2&&tr[sta[top-1]].dep>=tr[lca].dep)
ins(sta[top-1],sta[top]),top--;
if(lca!=sta[top])
{
ins(lca,sta[top]);
sta[top]=lca;
}
sta[++top]=x;
}
}
for(int i=1;i<top;i++) ins(sta[i],sta[i+1]);
ans=0;dfs(1);
printf("%lld\n",ans/2);
}
int main()
{
scanf("%d %d",&n,&m);
scanf("%s",s+1);
for(int i=n;i>=1;i--) addsam(s[i]-'a',n-i+1);
int x=root;
for(int i=n;i>=1;i--)
x=sam[x].a[s[i]-'a'],pre[n-i+1]=x;
for(int i=2;i<=tot;i++) ins(sam[i].par,i);
tr[0].dep=-1;dfs(1,0);
memset(last,0,sizeof(last));
while(m--)
{
int k;scanf("%d",&k);
tim++;num=0;
for(int i=1;i<=k;i++)
{
int x;scanf("%d",&x);
x=n-x+1;
if(check[pre[x]]!=tim)
{
check[pre[x]]=tim;
h[++num]=pre[x];
}
}
build();
}
}