求树上点对间距离为质数的对数。
做法:点分治,合并两条链发现是卷积形式,上fft。(注意清空,为这玩意调了大半天。。)
代码:
#include<bits/stdc++.h>
#define db double
#define ll long long
using namespace std;
const int N=1e5+10;
const db pi=acos(-1);
struct cp{
db r,i;
cp(){};
cp(db r,db i):r(r),i(i){};
};
cp operator +(const cp &x,const cp &y)
{return cp(x.r+y.r,x.i+y.i);}
cp operator -(const cp &x,const cp &y)
{return cp(x.r-y.r,x.i-y.i);}
cp operator *(const cp &x,const cp &y)
{return cp(x.r*y.r-x.i*y.i,x.r*y.i+x.i*y.r);}
cp w[N],A[N],B[N];
int len,cc,wh[N],mxd,nwd,pri[N],pnum=0;
vector<int>mp[N];
int n,rt,sz,siz[N],minx,cnt[N],dis[N];
bool vis[N],np[N];
ll ans;
void pre()
{
for(int i=2;i<=50000;i++)
{
if(!np[i])pri[++pnum]=i;
for(int j=1;j<=pnum&&i*pri[j]<=50000;j++)
{
np[i*pri[j]]=1;
if(i%pri[j]==0)break;
}
}
}
void init(int n,int m)
{
len=2,cc=1;
while(len<n+m-1)len<<=1,cc++;
for(int i=0;i<len;i++)
w[i]=cp(cos(2.0*pi*i/len),sin(2.0*pi*i/len));
w[len]=w[0];
for(int i=1;i<len;i++)
wh[i]=(wh[i>>1]>>1)|((i&1)<<(cc-1));
}
void fft(cp *a,bool inv)
{
for(int i=0;i<len;i++)
if(i<wh[i])swap(a[i],a[wh[i]]);
cp tp;
for(int l=2;l<=len;l<<=1)
{
for(int i=0,md=l/2;i<len;i+=l)
{
for(int j=0;j<md;j++)
{
tp=(inv?w[len-len/l*j]:w[len/l*j])*a[i+j+md];
a[i+j+md]=a[i+j]-tp;
a[i+j]=a[i+j]+tp;
}
}
}
}
int fft(cp *a,int n,cp *b,int m)
{
init(n,m);
for(int i=n;i<len;i++)a[i]=cp(0,0);
fft(a,0),fft(b,0);
for(int i=0;i<len;i++)
a[i]=a[i]*b[i];
fft(a,1);
for(int i=0;i<len;i++)
a[i].r=(int)(a[i].r/len+0.5);
int res=0;
for(int i=1;i<=pnum;i++)
{
if(pri[i]>=len)break;
res+=a[pri[i]].r;
}
for(int i=0;i<=len;i++)
a[i]=b[i]=cp(0,0);
return res;
}
void dfs(int pos,int fa)
{
int mx=0,v;
siz[pos]=1;
for(int i=0;i<mp[pos].size();i++)
{
v=mp[pos][i];
if(vis[v]||v==fa)continue;
dfs(v,pos),mx=max(mx,siz[v]),siz[pos]+=siz[v];
}
mx=max(mx,sz-siz[pos]);
if(mx<minx)minx=mx,rt=pos;
}
void Rt(int x)
{rt=x,minx=1e9,dfs(x,0);}
void dfs1(int pos,int fa,int dep)
{
int v;
dis[dep]++;
for(int i=0;i<mp[pos].size();i++)
{
v=mp[pos][i];
if(v==fa||vis[v])continue;
dfs1(v,pos,dep+1);
}
nwd=max(nwd,dep);
}
void work(int x)
{
int v,tp=1;
mxd=0;
for(int i=0;i<mp[x].size();i++)
{
v=mp[x][i];
if(vis[v])continue;
nwd=0;dfs1(v,x,1);
for(int j=1;j<=pnum;j++)
{
if(pri[j]>nwd)break;
ans+=dis[pri[j]];
}
for(int j=0;j<=nwd;j++)
B[j]=cp(dis[j],0);
for(int j=0;j<=mxd;j++)
A[j]=cp(cnt[j],0);
ans+=fft(A,mxd+1,B,nwd+1);
mxd=max(mxd,nwd);
for(int j=0;j<=nwd;j++)
cnt[j]+=dis[j],dis[j]=0;
}
for(int i=0;i<=siz[x];i++)cnt[i]=0;
return;
}
void sol(int x)
{
int v;
vis[x]=1,work(x);
for(int i=0;i<mp[x].size();i++)
{
v=mp[x][i];
if(vis[v])continue;
sz=siz[v],Rt(v),sol(rt);
}
}
int main()
{
int u,v;
pre();
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
mp[u].push_back(v);
mp[v].push_back(u);
}
sz=n,Rt(1),sol(rt);
ll tot=1LL*n*(n-1)/2;
printf("%.6lf\n",1.0*ans/tot);
}
/*
6
2 1
3 1
4 3
5 4
6 5
*/