根据期望的线性性,我们算出每个点期望被计算次数,然后进行累加.
考虑点 $x$ 对点 $y$ 产生了贡献,那么说明 $(x,y)$ 之间的点中 $x$ 是第一个被删除的.
这个期望就是 $\frac{1}{dis(x,y)+1}$,所以我们只需求 $\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{1}{dis(i,j)+1}$ 即可.
然后这个直接求是求不出来的,所以需要用点分治+FFT来算树上每种距离都出现了多少次.
code:
#include <bits/stdc++.h>
using namespace std;
#define N 500003
#define ll long long
#define setIO(s) freopen(s".in","r",stdin)
const double pi=acos(-1);
ll ans[N];
int edges,root,sn,n,mxdep;
int size[N],mx[N],hd[N],to[N<<1],nex[N<<1],vis[N];
inline void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
struct cpx
{
double x,y;
cpx(double a=0,double b=0) { x=a,y=b; }
cpx operator+(const cpx b) { return cpx(x+b.x,y+b.y); }
cpx operator-(const cpx b) { return cpx(x-b.x,y-b.y); }
cpx operator*(const cpx b) { return cpx(x*b.x-y*b.y,x*b.y+y*b.x); }
}A[N],B[N];
void fft(cpx *a,int len,int flag)
{
int i,j,k,mid;
for(i=k=0;i<len;++i)
{
if(i>k) swap(a[i],a[k]);
for(j=len>>1;(k^=j)<j;j>>=1);
}
for(mid=1;mid<len;mid<<=1)
{
cpx wn(cos(pi/mid), flag*sin(pi/mid)),x,y;
for(i=0;i<len;i+=mid<<1)
{
cpx w(1,0);
for(j=0;j<mid;++j)
{
x=a[i+j],y=w*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
w=w*wn;
}
}
}
if(flag==-1) for(int i=0;i<len;++i) a[i].x/=(double)len;
}
void getroot(int u,int ff)
{
size[u]=1,mx[u]=0;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff||vis[v]) continue;
getroot(v,u);
size[u]+=size[v];
mx[u]=max(mx[u], size[v]);
}
mx[u]=max(mx[u], sn-size[u]);
if(mx[u]<mx[root]) root=u;
}
void dfs(int u,int ff,int d)
{
++A[d].x;
mxdep=max(mxdep,d);
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff||vis[v]) continue;
dfs(v,u,d+1);
}
}
void calc(int u,int d)
{
mxdep=0;
dfs(u,0,d==1?0:1);
int len=1;
while(len<=(mxdep+mxdep+2)) len<<=1;
fft(A,len,1);
for(int i=0;i<len;++i) A[i]=A[i]*A[i];
fft(A,len,-1);
for(int i=0;i<min(len,n);++i) ans[i]+=(ll)(A[i].x+0.1)*d;
for(int i=0;i<len;++i) A[i].x=A[i].y=0;
}
void solve(int u)
{
vis[u]=1;
calc(u,1);
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(vis[v]) continue;
calc(v,-1);
}
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(vis[v]) continue;
root=0,sn=size[v],getroot(v,u);
solve(root);
}
}
int main()
{
// setIO("input");
int i,j;
scanf("%d",&n);
for(i=1;i<n;++i)
{
int x,y;
scanf("%d%d",&x,&y);
++x,++y;
add(x,y),add(y,x);
}
mx[0]=sn=n;
getroot(1,0);
solve(root);
double tmp=0.0;
for(int i=0;i<n;++i)
{
tmp+=(double) ans[i]/(i+1);
}
printf("%.4f\n",tmp);
return 0;
}