题意简述
有
n
点的树,边权都为
数据范围
1≤n≤5×104
思路
点分治+FFT。
只需要求出所有长度为素数的路径条数。
点分治统计答案的时候,我们构建一个生成函数,表示每种距离的条数。
将其平方就是答案,这一步我们可以利用FFT高效实现。
但是多统计了在同一子树的方案,用同样的方法计算减去。
时间复杂度
O(nlog2n)
。
大常数需要跑1.4s。
#include<cstdio>
#include<cstring>
#include<cmath>
#include<complex>
using namespace std;
#define MAXN 50010
#define pi acos(-1)
#define MAXL 131080
typedef complex<double> C;
struct edge{
int s,t,next;
}e[MAXN<<1];
int head[MAXN],cnt;
void addedge(int s,int t)
{
e[cnt].s=s;e[cnt].t=t;e[cnt].next=head[s];head[s]=cnt++;
e[cnt].s=t;e[cnt].t=s;e[cnt].next=head[t];head[t]=cnt++;
}
int n,all,rt,u,v,f[MAXN],size[MAXN],num[MAXN];
bool vis[MAXN];
int len,tmp,ti;
int r[MAXL];
C a[MAXL];
int p_cnt;
int prime[MAXN];
bool not_prime[MAXN];
long long ans;
void fft(C *a,int f)
{
for (int i=0;i<len;i++)
if (i<r[i])
swap(a[i],a[r[i]]);
for (int i=1;i<len;i<<=1)
{
C wn(cos(pi/i),f*sin(pi/i));
for (int j=0;j<len;j+=(i<<1))
{
C w=1;
for (int k=0;k<i;k++,w*=wn)
{
C x=a[j+k],y=w*a[i+j+k];
a[j+k]=x+y,a[i+j+k]=x-y;
}
}
}
}
void get_rt(int node,int lastfa)
{
f[node]=0;
size[node]=1;
for (int i=head[node];i!=-1;i=e[i].next)
if (e[i].t!=lastfa&&!vis[e[i].t])
{
get_rt(e[i].t,node);
size[node]+=size[e[i].t];
f[node]=max(f[node],size[e[i].t]);
}
f[node]=max(f[node],all-size[node]);
if (f[node]<f[rt])
rt=node;
}
void get_dis(int node,int lastfa,int sum,int f)
{
num[sum]+=f;
for (int i=head[node];i!=-1;i=e[i].next)
if (e[i].t!=lastfa&&!vis[e[i].t])
get_dis(e[i].t,node,sum+1,f);
}
void solve(int node)
{
vis[node]=1;
for (int i=head[node];i!=-1;i=e[i].next)
if (!vis[e[i].t])
{
get_dis(e[i].t,node,1,1);
for (tmp=size[e[i].t]<<1,len=1,ti=0;len<=tmp;len<<=1)
ti++;
for (int j=0;j<len;j++)
r[j]=(r[j>>1]>>1)|((j&1)<<(ti-1));
for (int j=0;j<=size[e[i].t];j++)
a[j]=num[j];
for (int j=size[e[i].t]+1;j<len;j++)
a[j]=0;
fft(a,1);
for (int j=0;j<len;j++)
a[j]=a[j]*a[j];
fft(a,-1);
for (int j=0;j<len;j++)
a[j]/=len;
for (int j=1;j<=p_cnt&&prime[j]<len;j++)
ans-=(long long)(a[prime[j]].real()+0.5);
get_dis(e[i].t,node,1,-1);
}
get_dis(node,node,0,1);
for (tmp=size[node]<<1,len=1,ti=0;len<=tmp;len<<=1)
ti++;
for (int i=0;i<len;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(ti-1));
for (int i=0;i<=size[node];i++)
a[i]=num[i];
for (int i=size[node]+1;i<len;i++)
a[i]=0;
fft(a,1);
for (int i=0;i<len;i++)
a[i]=a[i]*a[i];
fft(a,-1);
for (int i=0;i<len;i++)
a[i]/=len;
for (int i=1;i<=p_cnt&&prime[i]<len;i++)
ans+=(long long)(a[prime[i]].real()+0.5);
get_dis(node,node,0,-1);
for (int i=head[node];i!=-1;i=e[i].next)
if (!vis[e[i].t])
{
all=size[e[i].t];
rt=0;
get_rt(e[i].t,e[i].t);
get_rt(rt,rt);
solve(rt);
}
}
void sieve(int n)
{
for (int i=2;i<=n;i++)
{
if (!not_prime[i])
prime[++p_cnt]=i;
for (int j=1;j<=p_cnt&&prime[j]*i<=n;j++)
{
not_prime[i*prime[j]]=1;
if (i%prime[j]==0)
break;
}
}
}
int main()
{
scanf("%d",&n);
sieve(n);
memset(head,0xff,sizeof(head));
cnt=0;
for (int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v);
}
rt=0;
f[rt]=n+1;
all=n;
get_rt(1,1);
get_rt(rt,rt);
solve(rt);
printf("%.8lf",1.0*ans/(1.0*n*(n-1)));
return 0;
}