3926: [Zjoi2015]诸神眷顾的幻想乡
Time Limit: 10 Sec Memory Limit: 512 MBSubmit: 1631 Solved: 947
[ Submit][ Status][ Discuss]
Description
幽香是全幻想乡里最受人欢迎的萌妹子,这天,是幽香的2600岁生日,无数幽香的粉丝到了幽香家门前的太阳花田上来为幽香庆祝生日。
Input
第一行两个正整数n,c。表示空地数量和颜色数量。
Output
一行,输出一个整数,表示答案。
Sample Input
0 2 1 2 1 0 0
1 2
3 4
3 5
4 6
5 7
2 5
Sample Output
HINT
对于所有数据,1<=n<=100000, 1<=c<=10。
Source
所以说标题这么霸气跟内容有什么关系咩?
题目大意就是统计树上不同串的个数... 发现我只会做一个串的于是默默打开了题解... 广义后缀自动机? 感觉什么东西加了广义都变得高大上了? 实际上广义后缀自动机就是处理多个串的情况.
首先一个比较显然(但是没想到啊)的结论就是原树上的任意一个子串一定会直线(深度单调递增)存在于某个原树里的叶子节点提起来变成的树里. 深度单调可以干嘛? 我们可以发现题目中给出的实际上就是一个trie, 然而我们建出一个trie的后缀自动机可以得到这个trie树里任意一个深度单调递增的子串(这个好像就已经叫广义后缀自动机了?). 那么把每个叶子节点作为根的trie都插入后缀自动机里就可以了... 这样一定可以在这个后缀自动机里表示出原树里的每个子串. 实际实现过程中, 一个trie的u节点插入的时候,last就是trie树上的父节点(正确性显然?). 每一个trie刚插进来的时候都从root开始. 这道题由于叶子结点最多就20个, 所以复杂度不会爆棚.
统计答案就很naive了, 跟统计普通串一样的. 根据parent树上的性质, 统计len[i] - len[par[i]].
#include<bits/stdc++.h>
using namespace std;
typedef long long lnt;
const int maxm = 1e5 + 5;
const int maxn = 4e6 + 5;
lnt ans;
int n, cc, tot, num, root;
int a[maxn][10], h[maxm], dep[maxn], par[maxn], in[maxm], col[maxm];
struct edge { int nxt, v; }e[maxm << 1];
inline const int read() {
register int x = 0;
register char ch = getchar();
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
inline void add(int u, int v) {
in[u] ++, in[v] ++;
e[++ num].v = v, e[num].nxt = h[u], h[u] = num;
e[++ num].v = u, e[num].nxt = h[v], h[v] = num;
}
inline int insert(int p, int c) {
int np = ++ tot;
dep[np] = dep[p] + 1;
while (p && !a[p][c]) a[p][c] = np, p = par[p];
if (!p) par[np] = root;
else {
int q = a[p][c];
if (dep[q] == dep[p] + 1) par[np] = q;
else {
int nq = ++ tot;
par[nq] = par[q];
dep[nq] = dep[p] + 1;
memcpy(a[nq], a[q], sizeof(a[q]));
par[np] = par[q] = nq;
while (a[p][c] == q) a[p][c] = nq, p = par[p];
}
}
return np;
}
void dfs(int u, int fa, int pre) {
int now = insert(pre, col[u]);
for (int i = h[u]; i; i = e[i].nxt)
if (e[i].v != fa) dfs(e[i].v, u, now);
}
int main() {
int u, v;
root = ++ tot;
n = read(), cc = read();
for (int i = 1; i <= n; ++ i) col[i] = read();
for (int i = 1; i < n; ++ i)
u = read(), v = read(), add(u, v);
for (int i = 1; i <= n; ++ i)
if (in[i] == 1) dfs(i, 0, 1);
for (int i = 2; i <= tot; ++ i)
ans += (lnt)dep[i] - dep[par[i]];
printf("%lld\n", ans);
}