Description
有一棵n个节点的树和一个长度为m的字符串S,树上每个节点有一个字符。问对于任意的有序数对(x,y),从x到y路径组成的字符串在S中出现次数的和。
n
,
m
≤
5
⋅
1
0
4
n,m\le5\cdot10^4
n,m≤5⋅104
Solution
很显然要在SAM上跑,我们要求的是所有路径组成的串在SAM上size的和,那么一个显然的响法就是我们点分治然后合并跨越不同子树的路径
记当前分治重心为rt,那么我们要求的就是若干条(x,rt)和(rt,y)这样路径的并组成的串在SAM上的size
这个可以正反建两个SAM,然后用SAM求出后缀树,那么所有(x,rt)都可以看作是c[rt]加上若干个前缀字符得到的新串,并且恰好就是它后缀树上到根的路径,那么我们对每个节点打出现次数+1的标记然后O(m)暴力下传就可以做到O(nlogn+nm)的复杂度了
我们发现这样做还是不行,考虑平衡规划一下。一个暴力就是我们枚举一条路径的起点,然后暴力dfs并顺便在SAM上求size。这样直接做是O(n^2)的。对于 > n >\sqrt n >n的子树我们用点分治搞,其余的子树我们直接上这个暴力。可以证明一棵树大小超过根号的子树不超过根号个,小于根号的我们一遇到就退掉也只会遍历根号个,那么总的复杂度就是十分科学的了
Code
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#define rep(i,st,ed) for (register int i=st;i<=ed;++i)
#define drp(i,st,ed) for (register int i=st;i>=ed;--i)
#define copy(x,t) memcpy(x,t,sizeof(x))
typedef long long LL;
const int INF=0x3f3f3f3f;
const int N=100005;
struct edge {int y,next;} e[N*2];
int ls[N],v[N],edCnt,rt,sum;
int size[N],mx[N],n,m,B;
char c[N],str[N]; LL ans;
bool del[N];
struct SAM {
int s[N],id[N],a[N],b[N],rig[N],size[N],tot,last;
int rec[N][26],son[N][26],tag[N],fa[N],mx[N];
int extend(int ch,int pos) {
int p,q,np,nq;
p=last; np=last=++tot; mx[np]=mx[p]+1;
rig[np]=pos; size[np]++;
for (;p&&!rec[p][ch];p=fa[p]) rec[p][ch]=np;
if (!p) fa[np]=1;
else {
q=rec[p][ch];
if (mx[q]==mx[p]+1) fa[np]=q;
else {
nq=++tot; mx[nq]=mx[p]+1;
copy(rec[nq],rec[q]);
fa[nq]=fa[q];
fa[np]=fa[q]=nq;
for (;p&&rec[p][ch]==q;p=fa[p]) rec[p][ch]=nq;
}
}
return last;
}
void build() {
tot=last=1;
rep(i,1,m) id[i]=extend(s[i],i);
rep(i,1,tot) b[mx[i]]++;
rep(i,1,tot) b[i]+=b[i-1];
drp(i,tot,1) a[b[mx[i]]--]=i;
drp(i,tot,2) {
int x=a[i];
if (!rig[fa[x]]) rig[fa[x]]=rig[x];
son[fa[x]][s[rig[x]-mx[fa[x]]]]=x;
size[fa[x]]+=size[x];
}
}
void col(int x,int fa,int now,int len) {
if (len==mx[now]) now=son[now][c[x]];
else if (s[rig[now]-len]!=c[x]) now=0;
if (!now) return ;
tag[now]++;
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y==fa||del[e[i].y]) continue;
col(e[i].y,x,now,len+1);
}
}
void push() {
rep(i,1,tot) tag[a[i]]+=tag[fa[a[i]]];
}
void clear() {
rep(i,1,tot) tag[i]=0;
}
} SAM1,SAM2;
void add_edge(int x,int y) {
e[++edCnt]=(edge) {y,ls[x]}; ls[x]=edCnt;
e[++edCnt]=(edge) {x,ls[y]}; ls[y]=edCnt;
}
void get_size(int x,int fa) {
size[x]=1; mx[x]=0;
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y==fa||del[e[i].y]) continue;
get_size(e[i].y,x); size[x]+=size[e[i].y];
mx[x]=std:: max(mx[x],size[e[i].y]);
}
mx[x]=std:: max(mx[x],sum-size[x]);
if (mx[x]<mx[rt]) rt=x;
}
void dfs(int x,int fa,int now) {
now=SAM1.rec[now][c[x]];
if (!now) return ;
ans+=SAM1.size[now];
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y==fa||del[e[i].y]) continue;
dfs(e[i].y,x,now);
}
}
void get(int x,int fa) {
v[++v[0]]=x;
for (int i=ls[x];i;i=e[i].next) {
if (e[i].y==fa||del[e[i].y]) continue;
get(e[i].y,x);
}
}
void change(int x,int fa,int xs) {
SAM1.clear(); SAM2.clear();
if (fa) {
SAM1.col(x,fa,SAM1.son[1][c[fa]],1);
SAM2.col(x,fa,SAM2.son[1][c[fa]],1);
} else {
SAM1.col(x,0,1,0);
SAM2.col(x,0,1,0);
}
SAM1.push();
SAM2.push();
rep(i,1,m) ans+=SAM1.tag[SAM1.id[i]]*SAM2.tag[SAM2.id[m-i+1]]*xs;
}
void solve(int x) {
if (size[x]<=B) {
v[0]=0; get(x,0);
rep(i,1,v[0]) dfs(v[i],0,1);
return ;
}
del[x]=1;
change(x,0,1);
for (int i=ls[x];i;i=e[i].next) {
if (del[e[i].y]) continue;
change(e[i].y,x,-1);
}
for (int i=ls[x];i;i=e[i].next) {
if (del[e[i].y]) continue;
sum=size[e[i].y]; rt=0;
get_size(e[i].y,x);
solve(rt);
}
}
int main(void) {
scanf("%d%d",&n,&m); B=sqrt(n);
rep(i,2,n) {
int x,y; scanf("%d%d",&x,&y);
add_edge(x,y);
}
scanf("%s",c+1);
rep(i,1,n) c[i]-='a';
scanf("%s",str+1);
rep(i,1,m) SAM1.s[i]=str[i]-'a',SAM2.s[m-i+1]=str[i]-'a';
SAM1.build(),SAM2.build();
mx[rt=0]=INF; sum=n;
get_size(1,0);
solve(rt);
printf("%lld\n", ans);
return 0;
}