题意
有一棵n个节点的树和一个长度为m的字符串S,树上每个节点有一个字符。问对于任意的有序数对(x,y),从x到y路径组成的字符串在S中出现次数的和。
n,m<=50000
n
,
m
<=
50000
分析
因为是求出现次数,显然可以用到后缀自动机来做。
首先考虑两种不同的暴力做法:
暴力1:枚举每个点作为起点,然后把整棵树dfs一次,求出起点到每个点组成的路径的出现次数。由于sam的转移是 O(1) O ( 1 ) 的,所以这么做总的复杂度是 O(n2) O ( n 2 ) 。
暴力2:我们考虑求每个点作为路径的lca时候的贡献。设路径的lca为点Z,那么对于一条路径(X,Y),我们可以将其拆成(X,Z)和(Z,Y)两条路径。对于Z的字符在S中的匹配位置,我们可以求出有多少条形如(Z,Y)的路径在该位置匹配;同理我们可以求出有多少条形如(X,Z)的路径在S的反串对应位置上匹配,然后每一位两边的匹配数相乘的和就是这个点的贡献。但注意到有可能X和Y位于同一棵子树内,所以还要对每棵子树再求一次来去重。关于如何实现求每个位置的匹配路径数量,我们可以在正反两串的后缀树上打标记,最后下推到叶节点就好了。这么做因为求每个点贡献的时候都要把后缀树扫一遍,所以总的复杂度是 O(n2+nm) O ( n 2 + n m ) 。
再仔细思考一下不难发现暴力2可以通过点分治来优化。因为点分治后,所有分治子树的size和大小为
O(nlogn)
O
(
n
l
o
g
n
)
,所以总的复杂度就是
O(nlogn+nm)
O
(
n
l
o
g
n
+
n
m
)
。
那么现在我们的瓶颈就在于每次扫后缀树时的
O(m)
O
(
m
)
。
当当前的分治子树size较小时,我们暴力扫后缀树显然是一种浪费。那么我们可以怎么做呢?做法1!
我们不妨设一个阈值B,当分治子树的size不超过B时我们用做法1,不然就用做法2。
显然size超过B的分治子树只有不超过
O(⌊nB⌋)
O
(
⌊
n
B
⌋
)
个,若我们碰到一个size不超过B的子树就退出的话,也可以证明遍历到的size不超过B的子树只有
O(⌊nB⌋)
O
(
⌊
n
B
⌋
)
个。
解一下方程发现,当B是
m−−√
m
的时候时间复杂度取到最优。
这样总的复杂度就是
O((n+m)m−−√)
O
(
(
n
+
m
)
m
)
,就可以AC啦。
接下来讲一下如何在后缀树上打标记。
一开始我以为的是先在sam的DAG上跑,跑到最后一个点时就在该点打标记,然后最后把标记下传到叶节点。打完之后发现无论怎样都对不齐,仔细思考了一下发现这样是不对的。
假设我们建出了S的后缀树,当我们要加入字符串(Z,Y)的时候,因为后缀树本质是一棵所有后缀组成的trie,于是我们可以从后缀树的根开始往下跳。跳到最终点的时候,我们就在这里打一个标记,然后最后再把所有标记下传到叶节点,那么我们就可以知道每个后缀的前缀匹配了多少字符串。
我们又知道S的后缀树等于其反串的后缀自动机上的parents树,那么我们就可以先把反串的parents树搞出来,然后在parents树上用同样的方法来打标记即可。
换句话说,用sam的时候,当我们要在字符串的后面加入一个字符时,可以沿着sam的DAG跳,而当我们要在字符串的前面加入一个字符时,我们就可以沿着parents树往下跳。
这题其实还有一个优化,就是在用方法2去重的时候,也应该按照子树的大小来决定用哪一种方法。而我比较懒,所以就直接用了方法2来去做。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int N=50005;
int n,m,cnt,last[N],c[N],sum,root,f[N],size[N],id1[N],id2[N],tot,a[N],B;
struct edge{int to,next;}e[N*2];
char str[N];
LL ans;
bool vis[N];
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
struct SAM
{
int ls,sz,ch[N*2][26],mx[N*2],fa[N*2],b[N*2],a[N*2],tag[N*2],size[N*2],s[N],rig[N*2],son[N*2][26];
int extend(int x,int pos)
{
int p,q,np,nq;
p=ls;ls=np=++sz;mx[np]=mx[p]+1;size[np]++;rig[np]=pos;
for (;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
if (!p) fa[np]=1;
else
{
q=ch[p][x];
if (mx[q]==mx[p]+1) fa[np]=q;
else
{
nq=++sz;mx[nq]=mx[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];fa[q]=fa[np]=nq;
for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
}
}
return ls;
}
void build()
{
for (int i=1;i<=sz;i++) b[mx[i]]++;
for (int i=1;i<=sz;i++) b[i]+=b[i-1];
for (int i=sz;i>=1;i--) a[b[mx[i]]--]=i;
for (int i=sz;i>=2;i--)
{
int x=a[i],F=fa[x];
size[F]+=size[x];
rig[F]=!rig[F]?rig[x]:rig[F];
son[F][s[rig[x]-mx[F]]]=x;
}
}
void mark(int x,int fa,int now,int len)
{
if (len==mx[now]) now=son[now][c[x]];
else if (s[rig[now]-len]!=c[x]) now=0;
if (!now) return;
tag[now]++;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) mark(e[i].to,x,now,len+1);
}
void push()
{
for (int i=1;i<=sz;i++) tag[a[i]]+=tag[fa[a[i]]];
}
}sam1,sam2;
void get_root(int x,int fa)
{
size[x]=1;f[x]=0;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa||vis[e[i].to]) continue;
get_root(e[i].to,x);
size[x]+=size[e[i].to];
f[x]=max(f[x],size[e[i].to]);
}
f[x]=max(f[x],sum-size[x]);
if (!root||f[x]<f[root]) root=x;
}
void get_sum(int x,int fa)
{
sum++;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) get_sum(e[i].to,x);
}
void get(int x,int fa)
{
a[++tot]=x;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) get(e[i].to,x);
}
void dfs(int x,int fa,int now)
{
now=sam1.ch[now][c[x]];
if (!now) return;
ans+=sam1.size[now];
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) dfs(e[i].to,x,now);
}
void work(int x,int fa,int f)
{
for (int i=1;i<=sam1.sz;i++) sam1.tag[i]=0;
for (int i=1;i<=sam2.sz;i++) sam2.tag[i]=0;
int to=c[fa];
if (fa) sam1.mark(x,fa,sam1.son[1][to],1),sam2.mark(x,fa,sam2.son[1][to],1);
else sam1.mark(x,fa,1,0),sam2.mark(x,fa,1,0);
sam1.push();sam2.push();
for (int i=0;i<m;i++) ans+=(LL)f*sam1.tag[id1[i]]*sam2.tag[id2[m-1-i]];
}
void solve(int x)
{
if (sum<=B)
{
tot=0;get(x,0);
for (int i=1;i<=tot;i++) dfs(a[i],0,1);
for (int i=1;i<=tot;i++) vis[a[i]]=0;
return;
}
vis[x]=1;
work(x,0,1);
for (int i=last[x];i;i=e[i].next)
if (!vis[e[i].to]) work(e[i].to,x,-1);
for (int i=last[x];i;i=e[i].next)
{
if (vis[e[i].to]) continue;
sum=0;get_sum(e[i].to,x);
root=0;
get_root(e[i].to,x);
solve(root);
}
}
int main()
{
scanf("%d%d",&n,&m);B=sqrt(n);
for (int i=1;i<n;i++)
{
int x,y;scanf("%d%d",&x,&y);
addedge(x,y);
}
scanf("%s",str+1);
for (int i=1;i<=n;i++) c[i]=str[i]-'a';
scanf("%s",str);
sam1.sz=sam1.ls=sam2.sz=sam2.ls=1;
for (int i=0;i<m;i++) sam1.s[i]=str[i]-'a',id1[i]=sam1.extend(sam1.s[i],i);
reverse(str,str+m);
for (int i=0;i<m;i++) sam2.s[i]=str[i]-'a',id2[i]=sam2.extend(sam2.s[i],i);
sam1.build();sam2.build();
root=0;sum=n;
get_root(1,0);
solve(root);
printf("%lld",ans);
return 0;
}