Description
A君准备在Z国进行一次旅行,Z国中有n个城市,城市从1到n进行编号,其中1号城市为Z国首都。Z国的旅行交通网由n-1条单向道路构成,并且从任何一个城市出发都可以通过旅行网到达首都。
一条旅行交通网中的旅行线路,可以用线路上所经过的城市来描述,如{v1,v2,v3,……,vm},它表示一条经过了m个城市的旅行路线,且城市vi到城市vi+1有一条单向道路相连。
两个城市是相似的,当且仅当他们所连接的道路数相同。
若两条路线{u1,u2,……,up}与{v1,v2,……,vq},若p=q且∀1 ≤ i ≤ p,城市 u i 与 v i 是相似的,则 A君认为这两条旅行路线也是相似的。
现在A君想知道共有多少种不同的旅行路线,相似的若干条旅行路线只算一种。
Input
第一行一个整数n表示Z国城市个数
接下来n-1行每行两个整数x,y,表示一条从x到y的单向道路
Output
仅一行一个整数表示答案
Sample Input
3
2 1
3 1
Sample Output
3
Data Constraint
20%的数据:n ≤ 100
另有40%的数据:每个城市所连接的道路不超过20条
100%的数据:1≤n≤10^5
分析
简单的来说就是给一颗字典树,然后让你求上面有多少个不同的子串(不一定从根节点开始)。
那么只要在这棵字典树上建一棵广义的后缀自动机,然后把每个节点对应的字符串数量加起来即可。
代码
#include <bits/stdc++.h>
#define N 100005
struct NOTE
{
int to,next;
}e[N];
int cnt;
int next[N];
int n;
int d[N];
int fa[N * 2];
int max[N * 2];
std::map <int,int> ch[N * 2];
int read()
{
int x = 0,f = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
void add(int x,int y)
{
e[++cnt].to = y;
e[cnt].next = next[x];
next[x] = cnt;
}
int size;
int ins(int last,int x)
{
if (ch[last][x])
{
int p = last;
int np = ch[last][x];
if (max[np] == max[p] + 1)
last = np;
else
{
int q = ++size;
max[q] = max[p] + 1;
ch[q] = ch[np];
fa[q] = fa[np];
fa[np] = last = q;
for (;ch[p][x] == np; p = fa[p])
ch[p][x] = q;
}
return last;
}
int p,q,np,nq;
p = last;
last = np = ++size;
max[np] = max[p] + 1;
for (; !ch[p][x] && p; p = fa[p])
ch[p][x] = np;
if (!p)
fa[np] = 1;
else
{
q = ch[p][x];
if (max[q] == max[p] + 1)
fa[np] = q;
else
{
nq = ++size;
max[nq] = max[p] + 1;
ch[nq] = ch[q];
fa[nq] = fa[q];
fa[q] = fa[np] = nq;
for (; ch[p][x] == q; p = fa[p])
ch[p][x] = nq;
}
}
return last;
}
void dfs(int x,int p)
{
int tmp = ins(p,d[x]);
for (int i = next[x]; i; i = e[i].next)
dfs(e[i].to,tmp);
}
int tmp[N];
int tot;
void ls()
{
for (int i = 1; i <= n; i++)
tmp[++tot] = d[i];
std::sort(tmp + 1,tmp + tot + 1);
tot = std::unique(tmp + 1,tmp + tot + 1) - tmp - 1;
for (int i = 1; i <= n; i++)
d[i] = std::lower_bound(tmp + 1,tmp + tot + 1,d[i]) - tmp - 1;
}
int main()
{
freopen("route.in","r",stdin);
freopen("route.out","w",stdout);
n = read();
for (int i = 1; i < n; i++)
{
int x = read();
int y = read();
d[x]++;
d[y]++;
add(y,x);
}
ls();
size = 1;
dfs(1,1);
long long ans = 0;
for (int i = 1; i <= size; i++)
ans += max[i] - max[fa[i]];
printf("%lld\n",ans);
}