题意:
给一棵树,每个点上有一个字母,问对于所有(x,y),求x到y的路径所组成的字符串在S中出现次数的和。
题解:
先上题解:begined
CTSC2010 珠宝商新解
然后说说个人的垃圾理解。
首先考虑暴力,一种显然的做法就是从每个点开始dfs整棵树,顺便在sam上走,那么每次加上right集合的大小即可。
然后考虑一种没那么显然的暴力。
枚举lca,然后将路径拆成,(x->z)(z->y),那么假如可以求出在正串上的每个位置能匹配多少个(z->y)的串,然后再用反串求一次(z->x)的值,就可以将对应位置相乘算出过着的点的串的贡献(当然要减子树内的)。
怎么求这个东西呢,sam不好做,但是sam的parent树本质上就是一棵后缀树,所以能通过反串的parent树建出正串的后缀树。然后就可以从根节点开始走,将所有(z->y)的串走到的位置都标记一遍,最后遍历整棵后缀树,将标记下传。
这个暴力显然可以用点分治优化,然而每一层都要扫一遍后缀树是在是太慢了,换句话说,当子树较小时,完全没必要扫一整棵后缀树,怎么办?暴力1!。
所以当当前子树
size<m−−√
s
i
z
e
<
m
用做法一得到答案 ,否则就用做法2,这样就可以做到
O((n+m)m−−√)
O
(
(
n
+
m
)
m
)
code:
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
int n,m;
char str[100010],S[100010];
LL ans=0;
struct node{
int y,next;
}a[100010];int len=0,last[50010];
int sqr,id1[50010],id2[50010];
void ins(int x,int y)
{
a[++len].y=y;
a[len].next=last[x];last[x]=len;
}
int sum,root,vis[50010];
struct trnode{
int son,tot;
}tr[50010];
struct SAM
{
int root,tail,tot,ch[100010][26],par[100010],max[100010],ri[100010],le[100010],s[100010],tag[100010],son[100010][26];
int addsam(int c,int len)
{
int p=tail,np=++tot;
max[np]=len;ri[np]=1;le[np]=len;
for(;p&&!ch[p][c];p=par[p]) ch[p][c]=np;
tail=np;
if(!p) par[np]=root;
else
{
int q=ch[p][c];
if(max[q]==max[p]+1) par[np]=q;
else
{
int nq=++tot;
memcpy(ch[nq],ch[q],sizeof(ch[q]));par[nq]=par[q];
max[nq]=max[p]+1;le[nq]=le[q];
par[q]=par[np]=nq;
for(;p&&ch[p][c]==q;p=par[p]) ch[p][c]=nq;
}
}
return np;
}
int sum[100010],X[100010];
void build()
{
for(int i=1;i<=tot;i++) sum[max[i]]++;
for(int i=1;i<=tot;i++) sum[i]+=sum[i-1];
for(int i=tot;i>=1;i--) X[sum[max[i]]--]=i;
for(int i=tot;i>=2;i--)
{
int x=X[i],fa=par[x];
ri[fa]+=ri[x];
son[fa][s[le[x]-max[fa]]]=x;
}
}
void mark(int x,int fa,int now,int len)
{
if(!now) return;
if(len==max[now]) now=son[now][str[x]-'a'];
else if(s[le[now]-len]!=str[x]-'a') now=0;
if(!now) return;
tag[now]++;
for(int i=last[x];i;i=a[i].next)
if(a[i].y!=fa&&!vis[a[i].y]) mark(a[i].y,x,now,len+1);
}
void push() {for(int i=1;i<=tot;i++) tag[X[i]]+=tag[par[X[i]]];}
}sam1,sam2;
void find_root(int x,int fa)
{
tr[x].tot=1;tr[x].son=0;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa||vis[y]) continue;
find_root(y,x);tr[x].tot+=tr[y].tot;
if(tr[x].son<tr[y].tot) tr[x].son=tr[y].tot;
}
tr[x].son=max(tr[x].son,sum-tr[x].tot);
if(!root||tr[x].son<tr[root].son) root=x;
}
int num=0,g[50010];
void get(int x,int fa)
{
g[++num]=x;
for(int i=last[x];i;i=a[i].next)
if(a[i].y!=fa&&!vis[a[i].y]) get(a[i].y,x);
}
void get_sum(int x,int fa)
{
sum++;
for(int i=last[x];i;i=a[i].next)
if(a[i].y!=fa&&!vis[a[i].y]) get_sum(a[i].y,x);
}
void dfs(int x,int fa,int now)
{
now=sam1.ch[now][str[x]-'a'];
if(!now) return;
ans+=(LL)sam1.ri[now];
for(int i=last[x];i;i=a[i].next)
if(a[i].y!=fa&&!vis[a[i].y]) dfs(a[i].y,x,now);
}
void work(int x,int fa,int op)
{
for(int i=1;i<=sam1.tot;i++) sam1.tag[i]=0;
for(int i=1;i<=sam2.tot;i++) sam2.tag[i]=0;
int to=str[fa]-'a';
if(fa) sam1.mark(x,fa,sam1.son[1][to],1),sam2.mark(x,fa,sam2.son[1][to],1);
else sam1.mark(x,fa,1,0),sam2.mark(x,fa,1,0);
sam1.push();sam2.push();
for(int i=1;i<=m;i++) ans+=(LL)op*sam1.tag[id1[i]]*sam2.tag[id2[m-i+1]];
}
void solve(int x)
{
if(sum<=sqr)
{
num=0;get(x,0);
for(int i=1;i<=num;i++) dfs(g[i],0,sam1.root);
return;
}
vis[x]=1;
work(x,0,1);
for(int i=last[x];i;i=a[i].next)
if(!vis[a[i].y]) work(a[i].y,x,-1);
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(vis[y]) continue;
sum=0;get_sum(y,x);
root=0;find_root(y,x);
solve(root);
}
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y;scanf("%d %d",&x,&y);
ins(x,y);ins(y,x);
}
sam1.tot=sam1.root=sam1.tail=1;
sam2.tot=sam2.root=sam2.tail=1;
scanf("%s",str+1);scanf("%s",S+1);
for(int i=1;i<=m;i++) sam1.s[i]=S[i]-'a',id1[i]=sam1.addsam(S[i]-'a',i);
reverse(S+1,S+m+1);
for(int i=1;i<=m;i++) sam2.s[i]=S[i]-'a',id2[i]=sam2.addsam(S[i]-'a',i);
sam1.build();sam2.build();
root=0;sum=n;sqr=sqrt(n);
find_root(1,0);
solve(root);
printf("%lld",ans);
}