题目:
https://www.luogu.org/problemnew/show/P4218
分析:
一种显然的暴力就是枚举一个起点,在这个点进行dfs,然后在后缀自动机上跟着跳。跳到的点的right集大小即为这条路径的答案。这样做的复杂度是
O
(
n
2
)
O(n^2)
O(n2)。
树上的路径问题可以考虑点分治。显然一条路径可以被拆成两段,
x
x
x到根和根到
y
y
y。
这条路径的答案就是所有
x
x
x到根路径的字符串在原串中匹配的右边界的集合,和根到
y
y
y路径字符串匹配的左边界集合的交。
对于每一个位置
i
i
i记录有多少个位置以这里结尾(另一段可以反过来做)。显然就是代表
[
1
,
i
]
[1,i]
[1,i]的字符串代表的节点fail树上到根的路径的和。我们可以把树上所有这样的路径的对应节点找到,再像预处理深度的方法计算出每一个叶子节点的答案。
总的答案就是每个位置的乘积的和再减去每个儿子自己与自己匹配的答案。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <queue>
#define LL long long
const int maxn=1e5+7;
using namespace std;
int n,m,x,y,tot,num,root;
LL ans;
int ls[maxn],vis[maxn],f[maxn],size[maxn],b[maxn];
char a[maxn],s[maxn];
struct edge{
int y,next;
}g[maxn];
struct sam{
int cnt;
struct node{
int len,fail;
int son[26];
}t[maxn];
int tree[maxn][26];
int b[maxn],top[maxn],size[maxn],pos[maxn],id[maxn],tag[maxn],str[maxn];
void build_sam(char *s)
{
cnt=1;
int now=1,p,q,clone;
for (int i=1;i<=m;i++)
{
int c=s[i]-'a';
str[i]=c;
p=now;
now=++cnt;
t[now].len=t[p].len+1;
size[now]=1;
pos[now]=i;
id[i]=now;
while (p&&(!t[p].son[c])) t[p].son[c]=now,p=t[p].fail;
if (!p) t[now].fail=1;
else
{
q=t[p].son[c];
if (t[p].len+1==t[q].len) t[now].fail=q;
else
{
clone=++cnt;
t[clone]=t[q];
t[clone].len=t[p].len+1;
t[now].fail=t[q].fail=clone;
while (p&&(t[p].son[c]==q)) t[p].son[c]=clone,p=t[p].fail;
}
}
}
for (int i=1;i<=cnt;i++) b[t[i].len]++;
for (int i=1;i<=m;i++) b[i]+=b[i-1];
for (int i=1;i<=cnt;i++) top[b[t[i].len]--]=i;
for (int i=cnt;i>0;i--)
{
int x=top[i];
size[t[x].fail]+=size[x];
if (!pos[t[x].fail]) pos[t[x].fail]=pos[x];
tree[t[x].fail][str[pos[x]-t[t[x].fail].len]]=x;
}
}
void find(int x,int fa,int now)
{
now=t[now].son[a[x]-'a'];
if (!now) return;
ans+=(LL)size[now];
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if ((y==fa) || (vis[y])) continue;
find(y,x,now);
}
}
void dfs(int x,int fa,int now,int len)
{
if (len==t[now].len) now=tree[now][a[x]-'a'];
else if (str[pos[now]-len]!=a[x]-'a') now=0;
if (!now) return;
len++,tag[now]++;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if ((y==fa) || (vis[y])) continue;
dfs(y,x,now,len);
}
}
void getsum()
{
for (int i=1;i<=cnt;i++)
{
int x=top[i];
tag[x]+=tag[t[x].fail];
}
}
}A,B;
void add(int x,int y)
{
g[++tot]=(edge){y,ls[x]};
ls[x]=tot;
}
void findroot(int x,int fa)
{
size[x]=1;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if ((y==fa) || (vis[y])) continue;
findroot(y,x);
size[x]+=size[y];
f[x]=max(f[x],size[y]);
}
f[x]=max(f[x],num-size[x]);
if ((!root) || (f[x]<f[root])) root=x;
}
void getpoint(int x,int fa)
{
b[++tot]=x;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if ((y==fa) || (vis[y])) continue;
getpoint(y,x);
}
}
void calc(int x,int fa,int op)
{
for (int i=1;i<=A.cnt;i++) A.tag[i]=0;
for (int i=1;i<=B.cnt;i++) B.tag[i]=0;
if (fa) A.dfs(x,fa,A.tree[1][a[fa]-'a'],1),B.dfs(x,fa,B.tree[1][a[fa]-'a'],1);
else A.dfs(x,0,1,0),B.dfs(x,0,1,0);
A.getsum(),B.getsum();
for (int i=1;i<=m;i++) ans+=(LL)op*(LL)A.tag[A.id[i]]*(LL)B.tag[B.id[m-i+1]];
}
void solve()
{
queue <int> q;
size[1]=n;
q.push(1);
int block=trunc(sqrt(n));
while (!q.empty())
{
int x=q.front();
q.pop();
num=size[x];
if (num>block)
{
root=0;
findroot(x,0);
x=root;
calc(x,0,1);
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (vis[y]) continue;
calc(y,x,-1);
}
vis[x]=1;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (vis[y]) continue;
findroot(x,0);
q.push(y);
}
}
else
{
tot=0;
getpoint(x,0);
for (int i=1;i<=tot;i++) A.find(b[i],0,1);
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
scanf("%s",a+1);
scanf("%s",s+1);
A.build_sam(s);
reverse(s+1,s+m+1);
B.build_sam(s);
solve();
printf("%lld\n",ans);
}