这题调的我两眼发黑。。。
首先想 S A M SAM SAM建出来就是反串的后缀树,那么把原串反转一下,求后缀的 l c p lcp lcp就变成了求前缀的最长后缀,在 S A M SAM SAM上就是两个代表节点 l c a lca lca的 l e n len len,用这些关键点和他们的 l c a lca lca建出虚树,然后在虚树上跑,设 s i z a [ u ] , s i z b [ u ] siz_a[u],siz_b[u] siza[u],sizb[u]分别表示 u u u子树内 a , b a,b a,b节点的个数,则 a n s + = s i z a [ u ] ∗ s i z b [ u ] ∗ ( l u − l f a ) ans+=siz_a[u]*siz_b[u]*(l_u-l_{fa}) ans+=siza[u]∗sizb[u]∗(lu−lfa)
然后就是虚树的基本操作,把关键点按 d f n dfn dfn排序,两两求 l c a lca lca加入关键点中,然后再排序后建虚树,注意建树弹栈的时候要判断当前点是不是这个点的祖先,要用 l c a lca lca判断,还有清空的时候不能遍历 c n t cnt cnt不然就 t l e tle tle了,还有数组要开一倍
细节非常多了可以说
像我这样码力弱的一点一点调过来已经写的不成样子了,将就着看一眼代码吧
(本来打开这题是想做一下虚树结果发现 S A M SAM SAM全忘了于是从头开始学 S A M SAM SAM。。。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N 400005
#define LL long long
using namespace std;
template<class T>inline void rd(T &x){
x=0; short f=1; char c=getchar();
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
x*=f;
}
int n,q,id[N],sza,szb,qu[N],tot,rt,siz1[N],siz2[N],ty1[N],ty2[N];
int cnt,head[N],nxt[N<<1],to[N<<1],dep[N],f[N][20],dfn[N],num;
char s[N];
bool vis[N];
LL ans;
vector<int> vec[N];
struct qwq{
int pos,key;
qwq(const int pp=0,const int kk=0){pos=pp,key=kk;}
inline bool operator <(const qwq &x) const{
return key<x.key;
}
}a[N<<1];
inline void add(int x,int y){
to[++cnt]=y,nxt[cnt]=head[x],head[x]=cnt;
to[++cnt]=x,nxt[cnt]=head[y],head[y]=cnt;
}
struct SAM{
int lst,cnt,ch[N<<1][26],fa[N<<1],l[N<<1];
inline void insert(int c,int pos){
int p=lst,np=++cnt; id[pos]=cnt; lst=np; l[np]=l[p]+1;
while(p&&!ch[p][c]) ch[p][c]=np,p=fa[p];
if(!p) fa[np]=1;
else{
int q=ch[p][c];
if(l[p]+1==l[q]) fa[np]=q;
else{
int nq=++cnt; l[nq]=l[p]+1;
memcpy(ch[nq],ch[q],sizeof ch[q]);
fa[nq]=fa[q],fa[q]=fa[np]=nq;
while(ch[p][c]==q) ch[p][c]=nq,p=fa[p];
}
}
}
inline void build(){
scanf("%s",s+1); int len=strlen(s+1);
lst=cnt=1; for(int i=len;i;i--) insert(s[i]-'a',i);
for(int i=1;i<=cnt;i++)
if(fa[i]) add(fa[i],i);
}
}sam;
void dfs(int u,int fa){
dfn[u]=++num;
for(int i=head[u],v;i;i=nxt[i])
if((v=to[i])!=fa){
dep[v]=dep[u]+1; f[v][0]=u;
for(int j=1;j<=18;j++)
f[v][j]=f[f[v][j-1]][j-1];
dfs(v,u);
}
}
inline int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=18;i>=0;i--)
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
for(int i=18;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
inline void init(){
for(int i=1;i<=sam.cnt;i++)
vec[i].clear(),vis[i]=0,ty1[i]=ty2[i]=0; tot=0;
}
void dfs2(int u,int fa){
siz1[u]=ty1[u],siz2[u]=ty2[u];
for(int i=0;i<vec[u].size();i++) dfs2(vec[u][i],u),siz1[u]+=siz1[vec[u][i]],siz2[u]+=siz2[vec[u][i]];
ans+=1LL*siz1[u]*siz2[u]*(sam.l[u]-sam.l[fa]);
}
inline LL solve(){
ans=0; int pt=tot,now=0;
for(int i=1;i<pt;i++){
int z=LCA(a[i].pos,a[i+1].pos);
if(!vis[z]) vis[z]=1,a[++tot]=qwq(z,dfn[z]);
}
sort(a+1,a+tot+1); qu[++now]=rt=a[1].pos;
for(int i=2;i<=tot;i++){
if(a[i].pos==a[i-1].pos) continue;
while(now>1 && LCA(qu[now],a[i].pos)!=qu[now]) now--;
vec[qu[now]].push_back(a[i].pos); qu[++now]=a[i].pos;
}
dfs2(rt,0); return ans;
}
int main(){
rd(n); rd(q);
sam.build(); dfs(1,0);
while(q--){
rd(sza),rd(szb); int x;
for(int i=1;i<=sza;i++){
rd(x); a[++tot]=qwq(id[x],dfn[id[x]]);
vis[id[x]]=1; ty1[id[x]]=1;
}
for(int i=1;i<=szb;i++){
rd(x); if(!vis[id[x]]) a[++tot]=qwq(id[x],dfn[id[x]]);
vis[id[x]]=1; ty2[id[x]]=1;
}
sort(a+1,a+tot+1);
printf("%I64d\n",solve());
if(q){
for(int i=1;i<=tot;i++)
x=a[i].pos,vec[x].clear(),vis[x]=0,ty1[x]=ty2[x]=0;
tot=0;
}
}
return 0;
}