传送门
给你一棵树,树上每个点有字符,询问所有树上路径形成的字符串在给定模式串中一共出现了几次。在模式串中不同位置出现要多次计算。
题解:
好题啊。
树上路径统计类的问题显然考虑的一般就是链分治和点分治。
而这种路径形成字符串的东西显然不是链分治可以搞的。
考虑点分治,假设当前分治中心为 u u u,则我们需要考虑 v → u → w v\rightarrow u\rightarrow w v→u→w的路径。
我们处理从分治中心向外的路径。接下来需要考虑拼接,考虑这个拼接的位置,发现实际上我们需要考虑的就是,对于每个前缀,它的后缀出现了多少,对于每个后缀,它的前缀出现了多少。然后枚举拼接位置计算答案。
建两棵后缀树即可。建后缀树可以用SAM,但是注意这时候匹配应该是在后缀树的压缩Trie上匹配,而不是用SAM的转移进行匹配。
然后SAM来一个下放标记即可算答案了。注意到这样每次会有一个 O ( m ) O(m) O(m)的时间复杂度,所以在分治子树过小的时候我们需要直接利用SAM暴力匹配求一个right集合的大小。
注意这个分治不仅是对于中心分治的情况,对于单独处理子树内部的情况也需要,不然会被蒲公英图给卡。
子树大小的阈值直接用 n \sqrt n n即可达到最优复杂度。
下面的代码由于用了大量的dfs,有大概 × 5 \times 5 ×5的常数,洛谷上开O2最慢的点两秒,全部改成BFS应该能卡进500ms,但是没什么意义,懒得管了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const
using std::cerr;
using std::cout;
using pii=std::pair<int,int>;
#define fi first
#define se second
cs int N=1e5+7;
cs int B=400;
int n,m;
ll ans;
struct SAM{
static cs int N=::N;
char s[N];int son[N][26],tr[N][26],fa[N],len[N],ps[N],now,last;
SAM(){now=last=1;}
inline void push_back(int c){
int cur=++now,p=last;
last=cur;len[cur]=len[p]+1;
for(;p&&!son[p][c];p=fa[p])son[p][c]=cur;
if(!p)fa[cur]=1;
else if(len[son[p][c]]==len[p]+1)fa[cur]=son[p][c];
else {
int nq=++now,q=son[p][c];ps[nq]=ps[q];
memcpy(son[nq],son[q],sizeof son[q]);
len[nq]=len[p]+1,fa[nq]=fa[q],fa[q]=fa[cur]=nq;
for(;p&&son[p][c]==q;p=fa[p])son[p][c]=nq;
}
}
std::vector<int> e[N];
int bin[N],nd[N],siz[N];
inline void init(){
for(int re i=1;i<=m;++i)push_back(s[i]),ps[last]=i;
for(int re i=2;i<=now;++i)e[fa[i]].push_back(i);
for(int re i=1;i<=now;++i)++bin[len[i]];
for(int re i=1;i<=m;++i)bin[i]+=bin[i-1];
for(int re i=now;i;--i)nd[bin[len[i]]--]=i;
for(int re i=now;i;--i){
int u=nd[i];
if(ps[u]==len[u])++siz[u];
siz[fa[u]]+=siz[u];
}
for(int re u=1;u<=now;++u)
for(int re v:e[u])tr[u][s[ps[v]-len[u]]]=v;
memset(bin,0,sizeof bin);
}
inline pii trans(pii t,int c)cs{
if(!t.fi)return pii(0,0);
if(t.se==len[t.fi])return pii(tr[t.fi][c],t.se+1);
if(s[ps[t.fi]-t.se]==c)return pii(t.fi,t.se+1);
return pii(0,0);
}
inline void add(pii t,int c){++bin[trans(t,c).fi];}
inline void add(pii t){++bin[t.fi];}
}S,T;
SAM *t;
int c1[N],c2[N];
void dfs_sam(int u,int sum,int flag){
sum+=t->bin[u];
if(t->ps[u]==t->len[u])(flag?c2[m-t->ps[u]+1]:c1[t->ps[u]])+=sum;
for(int re v:t->e[u])dfs_sam(v,sum,flag);
}
inline ll calc(){
ll res=0;
memset(c1+1,0,sizeof(int)*m);
memset(c2+1,0,sizeof(int)*m);
t=&S;dfs_sam(1,0,0);t=&T;dfs_sam(1,0,1);
for(int re i=1;i<=m;++i)res+=(ll)c1[i]*c2[i];
memset(S.bin+1,0,sizeof(int)*S.now);
memset(T.bin+1,0,sizeof(int)*T.now);
return res;
}
char s[N];
std::vector<int> G[N];
int fa[N],siz[N],ban[N];
int maxn,total,Gr;
void get_siz(int u,int p){
siz[u]=1;for(int re v:G[u])
if(v!=p&&!ban[v])get_siz(v,u),siz[u]+=siz[v];
}
void find_G(int u,int p){
int mx=total-siz[u];for(int re v:G[u])
if(v!=p&&!ban[v])find_G(v,u),mx=std::max(mx,siz[v]);
if(mx<maxn){maxn=mx,Gr=u;}
}
inline void get_G(int u){
get_siz(u,0);
total=siz[u],maxn=1e9;
find_G(u,0);
}
namespace Force{
int rt,flag;
void dfs2(int u,int p,int t){
if(!t)return ;
ans+=flag*S.siz[t];
for(int re v:G[u])if(v!=p&&!ban[v])
dfs2(v,u,S.son[t][s[v]]);
}
void dfs1(int u,int p){
int t=1;
if(flag<0){
for(int re v=u;;v=fa[v]){t=S.son[t][s[v]];if(v==rt||!t)break;}
t=S.son[t][s[Gr]];
dfs2(rt,0,S.son[t][s[rt]]);
}
else dfs2(u,0,S.son[t][s[u]]);
for(int re v:G[u])if(v!=p&&!ban[v])dfs1(v,u);
}
inline void solve(int u,int typ){
rt=u;flag=typ;dfs1(u,0);
}
}
void dfs(int u,int p,pii t1,pii t2){
S.add(t1=S.trans(t1,s[u])),T.add(t2=T.trans(t2,s[u]));
if(!t1.fi&&!t2.fi)return ;
for(int re v:G[u])if(v!=p&&!ban[v])dfs(v,u,t1,t2);
}
inline void calc_sub(int u,int sz,pii p1,pii p2){
if(sz<=B)return Force::solve(u,-1);
dfs(u,0,p1,p2);ans-=calc();
}
void pre_dfs(int u,int p){
fa[u]=p;for(int re v:G[u])if(v!=p&&!ban[v])pre_dfs(v,u);
}
inline void solve_G(int u){
if(total<=B)return Force::solve(u,1);
ban[u]=true;
pii p(1,0),p1=S.trans(p,s[u]),p2=T.trans(p,s[u]);
S.add(p,s[u]);T.add(p,s[u]);
for(int re v:G[u])if(!ban[v]){
pre_dfs(v,u);
dfs(v,u,p1,p2);
}
ans+=calc();
for(int re v:G[u])if(!ban[v])calc_sub(v,siz[v]>siz[u]?total-siz[u]:siz[v],p1,p2);
for(int re v:G[u])if(!ban[v]){
get_G(v);
solve_G(Gr);
}
}
signed main(){
#ifdef zxyoi
freopen("jewelry.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
for(int re i=1;i<n;++i){
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
scanf("%s",s+1);for(int re i=1;i<=n;++i)s[i]-='a';
scanf("%s",S.s+1);for(int re i=1;i<=m;++i)S.s[i]-='a';
std::reverse_copy(S.s+1,S.s+m+1,T.s+1);
S.init();T.init();
get_G(1);solve_G(Gr);
cout<<ans<<"\n";
return 0;
}