题目给定一棵树,树上每个结点有一个特定的值,定义一个序列是任选两点 A , B A, B A,B(可以相同), A A A 点到 B B B 点之间的所有结点上的值构成的一个序列。若 A = B A=B A=B,自然这个序列长度就为 1 1 1啦。
题目问你有多少种本质不同的序列,有个小细节,题目说度为 1 1 1 的结点不超过 20 20 20 个,这里我们就可以从所有的叶子结点出发然后插入所有字符串。
要找本质不同的字符串数量,可以联想到后缀自动机求 m a x l e n s e n d p o s − m a x l e n s l i n k maxlens_{endpos}-maxlens_{link} maxlensendpos−maxlenslink,其总和就是本质不同的子串数量。然后这题树形结构上面插入字符串,也就是字符串的前缀相同,这就要我们及时更改 l a s t last last 标记。同时为了实现在线插入多串,我们用广义后缀自动机。
有个小点就是这题的数据范围,数组要开足够大,我因为数组不够大 W A WA WA 了好几发,应该预留 2 ∗ 20 ∗ n 2*20*n 2∗20∗n 的空间大小才正确(因为遍历了字符串 20 20 20 遍,最坏情况当做插入了 20 20 20 个不同的串,然后建后缀自动机再预留多一倍空间)。
代码如下:
#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
typedef long long LL;
const int maxn = 4e6 + 5;
char s[maxn];
int sz, last, cnt; // sz = size
int head[maxn], color[maxn], out[maxn];
LL ans;
struct state{
int len, link;
int next[11];
}st[maxn];
struct node{
int v, next;
}edge[maxn];
inline void add(int u, int v){
edge[cnt].next = head[u];
edge[cnt].v = v;
head[u] = cnt++;
}
inline void init(){
memset(head, -1, sizeof(head));
st[0].len = 0;
st[0].link = -1;
sz = 0;
last = 0;
}
inline void extend(int c){
if(st[last].next[c]){
int p = last, x = st[last].next[c];
if(st[x].len == st[p].len + 1)
last = x;
else{
int y = ++sz;
st[y] = st[x];
st[y].len = st[p].len + 1;
st[x].link = y;
while(~p && st[p].next[c] == x){
st[p].next[c] = y;
p = st[p].link;
}
last = y;
}
return;
}
int now = ++sz;
st[now].len = st[last].len + 1;
int p = last;
while(~p && !st[p].next[c]){
st[p].next[c] = now;
p = st[p].link;
}
if(p == -1)
st[now].link = 0;
else{
int q = st[p].next[c];
if(st[p].len + 1 == st[q].len)
st[now].link = q;
else{
int clone = ++sz;
st[clone] = st[q];
st[clone].len = st[p].len + 1;
st[q].link = st[now].link = clone;
while(~p && st[p].next[c] == q){
st[p].next[c] = clone;
p = st[p].link;
}
}
}
last = now;
ans += st[last].len - st[st[last].link].len;
}
inline void dfs(int u, int fa){
extend(color[u]);
int now = last;
for(int k = head[u]; ~k; k = edge[k].next){
if(edge[k].v != fa){
last = now;
dfs(edge[k].v, u);
}
}
}
int main(){
cin.tie(0);
cout.tie(0);
ios::sync_with_stdio(false);
init();
int n, k, c, u, v;
cin >> n >> k;
for(int i = 1; i <= n; i++) cin >> color[i];
for(int i = 1; i < n; i++){
cin >> u >> v;
out[u]++;
out[v]++;
add(u, v);
add(v, u);
}
for(int i = 1; i <= n; i++){
if(out[i] == 1){
last = 0;
dfs(i, i);
}
}
cout << ans << endl;
}