因为叶子节点很少 我们把20个trie合并成广义后缀自动机
然后直接跑就好了
trie树构建自动机的时候 回溯的时候把last还原
#include<cstdio>
#include<cstdlib>
#include<set>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
if (p1==p2) { p2=(p1=buf)+fread(buf,1,100000,stdin); if (p1==p2) return EOF; }
return *p1++;
}
inline void read(int &x){
char c=nc(),b=1;
for (;!(c>='0' && c<='9');c=nc()) if (c=='-') b=-1;
for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc()); x*=b;
}
const int M=2000005;
int C;
struct state{
int link,len,next[10];
}st[M];
int ncnt,last;
inline void Extend(char c){
int cur=st[last].next[c];
if (cur){
if (st[cur].len==st[last].len+1)
last=cur;
else{
int r=++ncnt;
st[r].len=st[last].len+1;
st[r].link=st[cur].link; st[cur].link=r;
for (int i=0;i<C;i++) st[r].next[i]=st[cur].next[i];
for (int p=last;p!=-1 && st[p].next[c]==cur;p=st[p].link)
st[p].next[c]=r;
last=r;
}
}else{
cur=++ncnt; int p;
st[cur].len=st[last].len+1;
for (p=last;p!=-1 && !st[p].next[c];p=st[p].link)
st[p].next[c]=cur;
if (p==-1)
st[cur].link=0;
else{
int q=st[p].next[c];
if (st[q].len==st[p].len+1)
st[cur].link=q;
else{
int nq=++ncnt;
st[nq].len=st[p].len+1;
st[nq].link=st[q].link;
for (int i=0;i<C;i++) st[nq].next[i]=st[q].next[i];
for (;p!=-1 && st[p].next[c]==q;p=st[p].link)
st[p].next[c]=nq;
st[q].link=st[cur].link=nq;
}
}
last=cur;
}
}
const int N=100005;
struct edge{
int u,v,next;
}G[N<<1];
int head[N],inum;
inline void add(int u,int v,int p){
G[p].u=u; G[p].v=v; G[p].next=head[u]; head[u]=p;
}
int val[N];
#define V G[p].v
inline void dfs(int u,int fa){
Extend(val[u]);
int t=last;
for (int p=head[u];p;p=G[p].next)
if (V!=fa)
dfs(V,u),last=t;
}
int n;
int deg[N];
int main(){
int iu,iv;
freopen("t.in","r",stdin);
freopen("t.out","w",stdout);
read(n); read(C);
for (int i=1;i<=n;i++) read(val[i]);
for (int i=1;i<n;i++)
read(iu),read(iv),add(iu,iv,++inum),add(iv,iu,++inum),deg[iu]++,deg[iv]++;
st[0].link=-1;
for (int i=1;i<=n;i++)
if (deg[i]==1)
last=0,dfs(i,0);
ll ans=0;
for (int i=1;i<=ncnt;i++)
ans+=st[i].len-st[st[i].link].len;
printf("%lld\n",ans);
return 0;
}