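// Problem: given a tree whose vertices carry colors, count the number of distinct
// color sequences read along the path between any two vertices. The tree is
// guaranteed to have at most 20 leaves.
// Idea: every path is a substring of some downward path that starts at a leaf
// (with the tree rooted at that leaf). (A very neat observation.) So we root the
// tree at each of the at most 20 leaves, treat each rooted tree as a trie, and
// insert all of them into one generalized suffix automaton. The answer is the
// number of distinct substrings it recognizes: the sum of len[v] - len[fa[v]]
// over all non-root states v.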
#include <bits/stdc++.h>
using namespace std;
// Capacity limits shared by the whole program:
//   maxn    - size of the per-vertex arrays (a, he, du); the SAM arrays use maxn<<1
//   maxm    - number of adjacency-list edge slots (ver, ne)
//   ch_size - alphabet size, i.e. number of distinct colors / SAM transitions
constexpr int maxn = 1000005;
constexpr int maxm = 2000005;
constexpr int ch_size = 10;
// Generalized suffix automaton: several strings (here, root-to-vertex paths of
// several tries) are inserted into one automaton by pointing `last` at the state
// to extend from before each extend() call.
// State 0 is the root; fa[0] = -1 terminates every suffix-link chain.
// NOTE(review): arrays hold maxn<<1 states. A generalized SAM may need up to
// roughly twice the number of inserted characters, and here every vertex is
// inserted once per leaf. Confirm this fits the input limits.
struct SAM{
int fa[maxn<<1], // suffix link of each state
ch[maxn<<1][ch_size], // transitions: ch[v][c] is the target on character c (0 = no transition)
len[maxn<<1], // length of the longest string in the state
tot, // index of the last allocated state (states are 0..tot)
last; // state of the whole string inserted so far; extend() appends to it
// Reset to an automaton containing only the root state 0.
// Other states are cleared when they are allocated in extend().
void init(){
tot = last = 0;
fa[0] = -1;
len[0] = 0;
memset( ch[0],0,sizeof( ch[0] ) );
}
// Append character x to the string represented by `last`, and leave `last`
// pointing at the state for the extended string.
void extend( int x ){
int p = last;
// The extended string was already inserted (by an earlier branch/string).
// No new state is created; make sure the string is the longest one of its state.
if( ch[p][x] ){
int q = ch[p][x];
if( len[q] == len[p]+1 ) last = q; // q already ends exactly at this string
else{
// Split q: the clone takes the strings of length <= len[p]+1,
// and q keeps the longer ones.
int clone = ++tot;
len[clone] = len[p] + 1;
last = clone;
for (int i = 0; i < ch_size; i++) {
ch[clone][i] = ch[q][i];
}
fa[clone] = fa[q];
fa[q] = clone;
// Redirect transitions into q along the suffix-link chain of p to the clone.
while (p != -1) {
if (ch[p][x] == q)ch[p][x] = clone;
else break;
p = fa[p];
}
}
return;
}
// Ordinary SAM extension: new state for the extended string.
int cur = ++tot;
memset( ch[cur],0,sizeof( ch[cur] ) );
len[cur] = len[last]+1;
// Every suffix of the old string that has no x-transition gets one to cur.
while( p != -1 && !ch[p][x] ){
ch[p][x] = cur;
p = fa[p];
}
if( p == -1 ){
fa[cur] = 0; // reached past the root: cur's only proper suffix state is the root
}else {
int q = ch[p][x];
if (len[q] == len[p] + 1) {
fa[cur] = q;
} else {
// Split q as above; both q and cur link to the clone.
int clone = ++tot;
len[clone] = len[p] + 1;
for (int i = 0; i < ch_size; i++) {
ch[clone][i] = ch[q][i];
}
fa[clone] = fa[q];
fa[q] = fa[cur] = clone;
while (p != -1) {
if (ch[p][x] == q)ch[p][x] = clone;
else break;
p = fa[p];
}
}
}
last = cur;
}
}g;
int a[maxn],ver[maxm],ne[maxm],he[maxn],tot;
void init(){
tot = 1;
memset( he,0,sizeof(he));
}
void add( int x,int y ){
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
void build( int x,int fa,int last ){
g.last = last;
g.extend(a[x]);
last = g.last;
for( int cure = he[x];cure;cure = ne[cure] ){
int y = ver[cure];if( y == fa ) continue;
g.last = last;
build( y,x,last );
}
}
// du[v] = degree of vertex v in the tree.
int du[maxn];

// Reads the tree and inserts it, rooted at every leaf, into the generalized SAM.
// Prints the number of distinct color sequences over all paths, i.e. the
// number of distinct substrings recognized by the automaton.
// Input: n c, then the colors a[1..n], then n-1 edges.
// NOTE(review): colors index SAM transitions directly, so they are assumed to
// lie in [0, ch_size) — confirm against the input format.
int main(){
    init();g.init();
    int n,c;
    scanf("%d%d",&n,&c);
    for( int i = 1; i<= n;i++ ) scanf("%d",&a[i]);
    for( int i = 1;i < n;i++ ){
        int x,y;scanf("%d%d",&x,&y);du[x]++;du[y]++;
        add(x,y);add(y,x);
    }
    // Root the tree at every leaf. For n >= 2 no vertex has degree 0, so
    // `<= 1` equals `== 1`. For n == 1 the single vertex has degree 0 and must
    // still be inserted; otherwise the answer would be 0 instead of 1.
    for( int i = 1; i <= n;i++ ){
        if( du[i] <= 1 ){
            g.last = 0;
            build( i,0,0 );
        }
    }
    // Each non-root state v contributes len[v] - len[fa[v]] distinct substrings.
    long long ans = 0;
    for( int i = 1;i <= g.tot;i++ ){
        ans += g.len[i]-g.len[ g.fa[i] ];
    }
    printf("%lld\n",ans);
    return 0;
}