[CTSC2010]珠宝商
洛谷题目传送门
简要题意
给定一颗
n
n
n个节点的树,和一个长度为
m
m
m的模式串
S
S
S
树上每个节点都有一个字符
求树上所有路径的点的字符拼成的字符串在
S
S
S中的出现次数之和
解题思路
路径统计?一听就很点分治
字串出现次数?一听就很SAM
那这个题实际也就是这两个的结合了
I
首先有一个显然的
O
(
n
2
)
O(n^2)
O(n2)做法
建出
S
A
M
SAM
SAM
并求出每个节点代表的字串在
S
S
S中出现的次数,即为
s
i
z
[
x
]
siz[x]
siz[x]
那么我们直接枚举路径起点,然后跑
d
f
s
dfs
dfs,并同时在
S
A
M
SAM
SAM上跑
跑到一个节点就统计这个节点的
s
i
z
siz
siz就行了
II
考虑点分治
求出树的重心
r
r
r,然后求出经过重心的字符串的答案
因为设路径是
s
→
t
s\to t
s→t
那么可以把路径拆成两部分,
s
→
r
s\to r
s→r 和
r
→
t
r\to t
r→t
我们考虑枚举
S
S
S中的位置
计算有多少
s
→
r
s\to r
s→r的路径是以当前位置结束的
计算有短少
r
→
t
r\to t
r→t的路径是从当前位置开始的
乘起来就是:经过当前位置,且在当前位置时路径上正好是
r
r
r的路径条数
加起来就是答案了
观察到两个计算是对称的
我们只需要建出
S
S
S的反串的后缀自动机,然后在两个SAM上的操作就一模一样了
问题变化为求有多少
s
→
r
s\to r
s→r的路径经过某个
S
A
M
SAM
SAM中的节点
一个显然的思路是
我们从
r
r
r开始向下
d
f
s
dfs
dfs,并维护当前路径对应的是
S
A
M
SAM
SAM中的哪个节点
然后把节点权值加一
那么出现在一个节点中的路径条数就是parent树上,当前点到根节点的权值之和
现在还有一个问题
注意到我们必须在路径前加字符,然后在SAM上跑
思考怎么维护这个
设当前路径长度为
T
T
T,在
S
A
M
SAM
SAM上的节点是
x
x
x,某个
x
x
x出现的位置是
R
x
R_x
Rx,新加入个字符时
c
c
c
那么如果
T
<
l
e
n
[
x
]
T<len[x]
T<len[x]那么加字符如果可以对应某个节点,那么一定还是
x
x
x,直接判断
S
[
R
[
x
]
−
T
]
S[R[x]-T]
S[R[x]−T]是不是c就行了
T
=
l
e
n
[
x
]
T=len[x]
T=len[x]
我们需要找一个节点
y
y
y使得
f
a
[
y
]
=
x
fa[y]=x
fa[y]=x,且
S
[
R
[
y
]
−
l
e
n
[
f
a
[
y
]
]
]
=
c
S[R[y]-len[fa[y]]]=c
S[R[y]−len[fa[y]]]=c
可以预处理出来
别忘了点分治时同一个子树要容斥掉
我们分析一波复杂度
因为每次统计答案必须遍历
S
S
S一遍
所以复杂度
O
(
n
log
n
+
n
m
)
O(n\log n+nm)
O(nlogn+nm)
似乎更糟了
因此我们还有算法
I
I
I
III
III
III
设当前分治子树的大小为
W
W
W
来一波根号分治
如果
W
≥
n
W\geq \sqrt n
W≥n按照算法
I
I
II
II统计就可以了
如果
W
<
n
W<\sqrt n
W<n,按照
n
2
n^2
n2的做法做就行了
总复杂度
O
(
n
n
+
m
n
)
O(n\sqrt n+m\sqrt n)
O(nn+mn)
还有一个细节
就是容斥的时候如果子节点的子树小于
n
\sqrt n
n
也需要类似
O
(
n
2
)
O(n^2)
O(n2)的做法去做
否则会被菊花图卡到
O
(
n
m
)
O(nm)
O(nm)
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5+7;
int read()
{
int X=0;bool flag=0;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')flag=1;ch=getchar();}
while(ch>='0'&&ch<='9'){X=(X<<3)+(X<<1)+ch-'0';ch=getchar();}
if(flag) return ~(X-1);
return X;
}
struct edge
{
int y,next;
}e[N];
int link[N],t=0;
bool vis[N];
void add(int x,int y)
{
e[++t].y=y;
e[t].next=link[x];
link[x]=t;
}
int n,m,S;
char str[N];
int a[N];
int pos1[N],pos2[N];
struct SAM
{
int len[N],tr[N][26],fa[N];
int pre[N][26];
int tot,last,siz[N];
int fst[N],cnt[N];
SAM()
{
tot=last=1;
}
inline void copy(int x,int y)
{
fa[x]=fa[y];
len[x]=len[y];
for(int c=0;c<26;c++)
tr[x][c]=tr[y][c];
}
inline void Extend(int c,int x)
{
int p=last,np=last=++tot;
len[np]=len[p]+1;
siz[np]=1;
fst[np]=x;
while(p&&!tr[p][c])
{
tr[p][c]=np;
p=fa[p];
}
if(!p) fa[np]=1;
else
{
int q=tr[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else
{
int nq=++tot;
copy(nq,q);
len[nq]=len[p]+1;
fa[np]=fa[q]=nq;
while(p&&tr[p][c]==q)
{
tr[p][c]=nq;
p=fa[p];
}
}
}
}
int c[N],SA[N];
int s[N];
inline void Build()
{
for(int i=1;i<=tot;i++) c[len[i]]++;
for(int i=1;i<=tot;i++) c[i]+=c[i-1];
for(int i=1;i<=tot;i++) SA[c[len[i]]--]=i;
for(int i=tot;i>=2;i--)
{
int x=SA[i];
siz[fa[x]]+=siz[x];
fst[fa[x]]=fst[x];
pre[fa[x]][s[fst[x]-len[fa[x]]]]=x;
}
}
inline void clear()
{
for(int i=1;i<=tot;i++)
cnt[i]=0;
}
inline void spread()
{
for(int i=2;i<=tot;i++)
cnt[SA[i]]+=cnt[fa[SA[i]]];
}
inline void Addin(int x,int fat,int p,int T)
{
if(len[p]==T) p=pre[p][a[x]];
else if(s[fst[p]-T]!=a[x]) p=0;
if(!p) return;
// cout<<x<<' '<<p<<endl;
cnt[p]++;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==fat||vis[y]) continue;
Addin(y,x,p,T+1);
}
}
void Put()
{
for(int i=1;i<=tot;i++)
for(int c=0;c<26;c++)
if(tr[i][c]) cout<<i<<' '<<tr[i][c]<<' '<<c<<endl;
}
void push()
{
for(int i=2;i<=tot;i++)
cout<<fa[i]<<' '<<i<<endl;
}
}A,B;
int s[N];
int siz[N],son[N];
void getroot(int x,int pre,int &root,int tot)
{
siz[x]=1;
son[x]=0;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre||vis[y]) continue;
getroot(y,x,root,tot);
siz[x]+=siz[y];
son[x]=max(son[x],siz[y]);
}
son[x]=max(son[x],tot-siz[x]);
if(!root||son[root]>son[x]) root=x;
}
LL ans=0;
void getsiz(int x,int pre)
{
siz[x]=1;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre||vis[y]) continue;
getsiz(y,x);
siz[x]+=siz[y];
}
}
void getans(int x,int pre,int opt)
{
A.clear();
B.clear();
if(pre)
{
A.Addin(x,0,A.tr[1][a[pre]],1);
B.Addin(x,0,B.tr[1][a[pre]],1);
}
else
{
A.Addin(x,0,1,0);
B.Addin(x,0,1,0);
}
A.spread();
B.spread();
for(int i=1;i<=m;i++)
ans+=opt*A.cnt[pos1[i]]*B.cnt[pos2[m-i+1]];
}
int q[N],top=0;
int par[N];
void getpoint(int x,int pre)
{
q[++top]=x;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre||vis[y]) continue;
par[y]=x;
getpoint(y,x);
}
}
int ban=0;
void getpath(int x,int pre,int p,int opt)
{
p=A.tr[p][a[x]];
if(!p) return;
ans+=opt*A.siz[p];
// cout<<x<<' '<<A.siz[p]<<endl;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==pre||vis[y]||y==ban) continue;
getpath(y,x,p,opt);
}
}
void Blocks(int x)
{
top=0;
getpoint(x,0);
for(int i=1;i<=top;i++)
getpath(q[i],0,1,1);
}
void getpart(int x,int pre)
{
top=0;
int rt=x;
par[rt]=pre;
getpoint(x,pre);
for(int i=1;i<=top;i++)
{
int p=1,cur=q[i];
while(cur!=pre)
{
p=A.tr[p][a[cur]];
cur=par[cur];
}
p=A.tr[p][a[cur]];
getpath(x,0,p,-1);
}
}
void Divide(int x,int tot)
{
if(tot<=S)
{
Blocks(x);
return;
}
vis[x]=1;
getans(x,0,1);
// cout<<x<<"----"<<ans<<endl;
for(int i=link[x];i;i=e[i].next)
{
int y=e[i].y;
if(vis[y]) continue;
int z=0,T=siz[y];
if(siz[y]>S)
getans(y,x,-1);
else getpart(y,x);
// cout<<y<<' '<<ans<<endl;
getroot(y,x,z,siz[y]);
Divide(z,T);
}
}
int main()
{
n=read();
m=read();
S=sqrt(m);
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
scanf("%s",str+1);
for(int i=1;i<=n;i++)
a[i]=str[i]-'a';
scanf("%s",str+1);
for(int i=1;i<=m;i++)
s[i]=str[i]-'a';
for(int i=1;i<=m;i++)
{
A.Extend(s[i],i);
pos1[i]=A.last;
A.s[i]=s[i];
}
reverse(s+1,s+m+1);
for(int i=1;i<=m;i++)
{
B.Extend(s[i],i);
pos2[i]=B.last;
B.s[i]=s[i];
}
A.Build();
B.Build();
int root=0;
getroot(1,0,root,n);
Divide(root,n);
cout<<ans;
return 0;
}