%%%神犇的题解
转移不难想 关键是空间和时间
空间用轻重链的思想和指针转移
时间么 指针转移&启发式合并
复杂度的证明比较巧妙
“任意设一点作为根。令 f(a, d) 表示在以 a 点为根的子树中,与 a 距离为 d 的节点数;g(a, d) 表示在以
a 为根的子树中选择两个节点,满足剩下的一个节点 s 需在 a 子树外选择且与 a 的距离必须为 d 的方案数。
则方案数容易统计,且不难写出 DP 转移式。
注意到 f(a, d), g(a, d) 满足 d ≤ size(a),则转移式可以用启发式合并优化。再注意到若 a 只有一个儿子
b ,则:
f(a, x) = f(b, x ? 1)
g(a, x) = g(b, x + 1)
即 f(a) = f(b) ? 1, g(a) = g(b) + 1,这一步可以利用指针在 O(1) 内实现。则可以在 O(1) 时间内将某个
点 size 最大的儿子的 DP 数组转移到该点上。利用类似树链剖分的方式可以预处理出所需要的数组空间。
这个看似是nlog的
但实际上如果一个点v 在某个祖先u 对复杂度有+1的效果,说明v 不在u往下的最长链上,并且v 在 u 的儿子u’(也是 v 的祖先)的最长链上。则到u 的父亲时,v 不在 u的最长链上,则对答案没有贡献。所以贡献的总和是 O(n)的。”
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#define V G[p].v
using namespace std;
typedef long long ll;
inline char nc()
{
static char buf[100000],*p1=buf,*p2=buf;
if (p1==p2) { p2=(p1=buf)+fread(buf,1,100000,stdin); if (p1==p2) return EOF; }
return *p1++;
}
inline void read(int &x){
char c=nc(),b=1;
for (;!(c>='0' && c<='9');c=nc()) if (c=='-') b=-1;
for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc()); x*=b;
}
const int N=100005;
struct edge{
int u,v,next;
};
edge G[N<<1];
int head[N],inum;
inline void add(int u,int v,int p){
G[p].u=u; G[p].v=v; G[p].next=head[u]; head[u]=p;
}
int n;
int depth[N],bot[N],son[N];
ll Mem[N*10],*pnt=Mem;
ll *f[N],*g[N];
ll ans;
inline void dfs(int u,int fa){
bot[u]=u; depth[u]=depth[fa]+1;
for (int p=head[u];p;p=G[p].next)
if (V!=fa){
dfs(V,u);
if (depth[bot[V]]>depth[bot[u]])
bot[u]=bot[V],son[u]=V;
}
for (int p=head[u];p;p=G[p].next)
if (V!=fa && (u==1 || V!=son[u])){
int v=bot[V];
f[v]=(pnt+=depth[v]-depth[u]+1);
g[v]=(++pnt);
pnt+=((depth[v]-depth[u])<<1)+10;
}
}
inline void dp(int u,int fa){
for (int p=head[u];p;p=G[p].next)
if (V!=fa){
dp(V,u);
if (V==son[u])
f[u]=f[V]-1,g[u]=g[V]+1;
}
ans+=g[u][0]; f[u][0]=1;
for (int p=head[u];p;p=G[p].next)
if (V!=fa && V!=son[u]){
for(int j=0;j<=depth[bot[V]]-depth[u];j++)
ans+=f[u][j-1]*g[V][j]+g[u][j+1]*f[V][j];
for(int j=0;j<=depth[bot[V]]-depth[u];j++){
g[u][j-1]+=g[V][j];
g[u][j+1]+=f[u][j+1]*f[V][j];
f[u][j+1]+=f[V][j];
}
}
}
int main()
{
int iu,iv;
freopen("t.in","r",stdin);
freopen("t.out","w",stdout);
read(n);
for (int i=1;i<n;i++)
read(iu),read(iv),add(iu,iv,++inum),add(iv,iu,++inum);
dfs(1,0); dp(1,0);
printf("%lld\n",ans);
return 0;
}