题目描述
题解
对询问串建立 AC \text{AC} AC 自动机,考虑建出 fail \text{fail} fail 树, fail \text{fail} fail 树上节点所代表的串是这个节点子树内每个点所代表的的串的后缀。所以我们可以把链分成两条,把正反串都放入 AC \text{AC} AC 自动机中,对于一条链 ( l c a , u ) (lca,u) (lca,u) ,对于不包含 l c a lca lca 的子串,我们可以用根到 u u u 的答案减去根到包含 l c a lca lca 的子串的最上方的点的答案,那我们就可以记录一下询问串的结束节点, dfs \text{dfs} dfs 原树的时候也一起走 AC \text{AC} AC 自动机,进入的时候 + 1 +1 +1 ,回溯的时候 − 1 -1 −1 ,用树状数组维护区间和即可。然后如果子串包含了 l c a lca lca 的话发现这条路径上有效的点是 2 ∣ s ∣ 2|s| 2∣s∣ 的,于是拉出来做 kmp \text{kmp} kmp 即可。
代码
#include <bits/stdc++.h>
using namespace std;const int N=3e5+5;
int n,m,id[N],sz[N],dp[N],fa[19][N],hd[N],V[N];
int nx[N],tt,ne[N],a[N],su[N],tr[N][26],fi[N];
char t[N],s[N],up[N],W[N];vector<int>e[N];
struct O{int i,u,v;};vector<O>p[N];queue<int>Q;
void add(int u,int v,char c){
nx[++tt]=hd[u];V[hd[u]=tt]=v;W[tt]=c;
}
void dfs(int u,int fr){
dp[u]=dp[fa[0][u]=fr]+1;
for (int i=1;fa[i-1][fa[i-1][u]];i++)
fa[i][u]=fa[i-1][fa[i-1][u]];
for (int v,i=hd[u];i;i=nx[i])
if ((v=V[i])!=fr) up[v]=W[i],dfs(v,u);
}
int kmp(int n,int m){
if (n<m) return 0;
ne[0]=ne[1]=0;
for (int j,i=1;i<m;i++){
j=ne[i];
while(j && t[j]!=t[i]) j=ne[j];
if (t[j]==t[i]) ne[i+1]=j+1;
else ne[i+1]=0;
}
int j=0,v=0;
for (int i=0;i<n;i++){
while(j && s[i]!=t[j]) j=ne[j];
if (s[i]==t[j]) j++;
if (j==m) v++;
}
return v;
}
int ins(int m){
int v=0;
for (int i=0,j;i<m;i++){
j=t[i]-97;
if (!tr[v][j])
tr[v][j]=++tt;
v=tr[v][j];
}
return v;
}
void build(){
for (int i=0;i<26;i++)
if (tr[0][i]) Q.push(tr[0][i]);
while(!Q.empty()){
int u=Q.front();Q.pop();
for (int v,i=0;i<26;i++){
v=tr[u][i];
if (v) fi[v]=tr[fi[u]][i],Q.push(v);
else tr[u][i]=tr[fi[u]][i];
}
}
for (int i=1;i<=tt;i++) e[fi[i]].push_back(i);
}
void dfs(int u){
id[u]=++tt;sz[u]=1;
int z=e[u].size();
for (int v,i=0;i<z;i++)
v=e[u][i],dfs(v),sz[u]+=sz[v];
}
int lca(int u,int v){
if (dp[u]<dp[v]) swap(u,v);
for (int i=17;~i;i--)
if (dp[fa[i][u]]>=dp[v]) u=fa[i][u];
if (u==v) return u;
for (int i=17;~i;i--)
if (fa[i][u]!=fa[i][v])
u=fa[i][u],v=fa[i][v];
return fa[0][u];
}
void upd(int x,int v){
x=id[x];
for (;x<=tt;x+=x&-x) su[x]+=v;
}
int qry(int x){
int l=id[x]-1,r=id[x]+sz[x]-1,v=0;
for (;r;r-=r&-r) v+=su[r];
for (;l;l-=l&-l) v-=su[l];
return v;
}
void dfs(int u,int fr,int k){
upd(k,1);int z=p[u].size();
for (int i=0;i<z;i++)
a[p[u][i].i]+=p[u][i].v*qry(p[u][i].u);
for (int v,i=hd[u];i;i=nx[i])
if ((v=V[i])!=fr) dfs(v,u,tr[k][W[i]-97]);
upd(k,-1);
}
int Up(int u,int x){
if (x<0) return u;
for (int i=17;~i;i--)
if (x&(1<<i)) u=fa[i][u];
return u;
}
int main(){
cin>>n>>m;
for (int u,v,i=1;i<n;i++)
scanf("%d%d%s",&u,&v,t),
add(u,v,t[0]),add(v,u,t[0]);
dfs(1,0);tt=0;
for (int i=1,u,v,len,p1,p2,w,u1,u2,z;i<=m;i++){
scanf("%d%d%s",&u,&v,t);z=lca(u,v);
len=strlen(t);p1=ins(len);
reverse(t,t+len);p2=ins(len);w=0;
u1=u2=Up(u,dp[u]-dp[z]-len+1);
for (int j=1;j<=dp[u1]-dp[z];j++)
s[w++]=up[u2],u2=fa[0][u2];
if (dp[u]-dp[z]>=len)
p[u1].push_back((O){i,p2,-1}),
p[u].push_back((O){i,p2,1});
u1=u2=Up(v,dp[v]-dp[z]-len+1);
w+=dp[u1]-dp[z];
for (int j=1;j<=dp[u1]-dp[z];j++)
s[w-j]=up[u2],u2=fa[0][u2];
if (dp[v]-dp[z]>=len)
p[u1].push_back((O){i,p1,-1}),
p[v].push_back((O){i,p1,1});
reverse(t,t+len);a[i]+=kmp(w,len);
}
build();tt=0;dfs(0);dfs(1,0,0);
for (int i=1;i<=m;i++) printf("%d\n",a[i]);
return 0;
}