sol:
发现题目就是给了一颗字典树让你求所有点到根组成的子串中本质不同的子串个数。用树上sa就行了。
树上sa的核心就是储存下rank[i][j]表示i这个点向上走2^j-1这么多步后,在这么多步的子串中的排名。这是因为在树上hei是没有O(n)求解的那个性质的,所以要用储存的rank数组来log求sa[i]与sa[i-1]的lcp。
#include<iostream>
#include<algorithm>
#include<cstring>
#include<string>
#include<cstdio>
#include<cstdlib>
#include<cmath>
using namespace std;
const int N=1e6+7;
const int logN=18;
int n,m,tot;
int sa[N],rank[logN][N],w[N],x[N],tmp[N];
int fir[N],go[N],nex[N],f[logN][N];
inline int read()
{
char c;
bool flag=false;
while((c=getchar())>'9'||c<'0')
if(c=='-')flag=true;
int res=c-'0';
while((c=getchar())>='0'&&c<='9')
res=(res<<3)+(res<<1)+c-'0';
return flag?-res:res;
}
inline void Sa()
{
int u,v,m=n;
for(int i=1;i<=n;++i) w[x[i]]++;
for(int i=1;i<=m;++i) w[i]+=w[i-1];
for(int i=n;i>=1;--i) sa[w[x[i]]--]=i;
rank[0][sa[1]]=1;
m=1;
for(int i=2;i<=n;++i)
{
u=sa[i];v=sa[i-1];
if(x[u]!=x[v]) m++;
rank[0][u]=m;
}
for(int j=0;j<logN;++j)
{
for(int i=0;i<=m;++i) w[i]=0;
for(int i=1;i<=n;++i) w[rank[j][f[j][i]]]++;
for(int i=1;i<=m;++i) w[i]+=w[i-1];
for(int i=1;i<=n;++i)
tmp[w[rank[j][f[j][i]]]--]=i;
for(int i=0;i<=m;++i) w[i]=0;
for(int i=1;i<=n;++i) w[rank[j][i]]++;
for(int i=1;i<=m;++i) w[i]+=w[i-1];
for(int i=n;i>=1;--i)
sa[w[rank[j][tmp[i]]]--]=tmp[i];
rank[j+1][sa[1]]=1;
m=1;
for(int i=2;i<=n;++i)
{
u=sa[i];v=sa[i-1];
if(rank[j][u]!=rank[j][v]||rank[j][f[j][u]]!=rank[j][f[j][v]]) m++;
rank[j+1][u]=m;
}
if(m==n) break;
}
}
inline void add_edge(int a,int b)
{
nex[++tot]=fir[a];fir[a]=tot;go[tot]=b;x[b]++;
nex[++tot]=fir[b];fir[b]=tot;go[tot]=a;x[a]++;
}
int dep[N];
void dfs(int u)
{
int v,e;
dep[u]=dep[f[0][u]]+1;
for(e=fir[u];v=go[e],e;e=nex[e])
if(f[0][u]!=v)
{
f[0][v]=u;
dfs(v);
}
}
inline void rmq()
{
for(int j=1;j<logN;++j)
for(int i=1;i<=n;++i)
f[j][i]=f[j-1][f[j-1][i]];
}
int ans;
inline void Ans()
{
int a,b;
for(int i=1;i<=n;++i) ans+=dep[i];
for(int i=2;i<=n;++i)
{
a=sa[i-1];b=sa[i];
for(int j=17;j>=0;--j)
if((1<<j)<=dep[b]&&rank[j][a]==rank[j][b])
{
ans-=1<<j;
a=f[j][a];b=f[j][b];
if(!a||!b) break;
}
}
}
int main()
{
freopen("route.in","r",stdin);
freopen("route.out","w",stdout);
n=read();
int a,b;
for(int i=1;i<n;++i)
{
a=read();
b=read();
add_edge(a,b);
}
dfs(1);
rmq();
Sa();
Ans();
printf("%d",ans);
}