题目描述
Problem description.
You are given a tree. If we select 2 distinct nodes uniformly at random, what’s the probability that the distance between these 2 nodes is a prime number?
Input
The first line contains a number N: the number of nodes in this tree.
The following N-1 lines contain pairs a[i] and b[i], which means there is an edge with length 1 between a[i] and b[i].
Output
Output a real number denote the probability we want.
You’ll get accept if the difference between your answer and standard answer is no more than 10^-6.
Constraints
2 ≤ N ≤ 50,000
The input must be a tree.
Example
Input:
5
1 2
2 3
3 4
4 5
Output:
0.5
Explanation
We have C(5, 2) = 10 choices, and these 5 of them have a prime distance:
1-3, 2-4, 3-5: 2
1-4, 2-5: 3
Note that 1 is not a prime number.
题目分析
使用树分治后我们能够得到任意点到重心的距离,现在问题转换为如何判断两点之间的距离为素数。
首先我们要使用线性筛处理出一个素数表。
一个简单的想法是我们将Dis
数组中不同的点加起来就是不同两点之间的距离。然后再判断这个距离是不是素数,如果是素数则将路径+1。可是这样时间复杂度太高(O(n*n)
),因此我们需要使用FFT进行加速,快速得到两路径的和。但是我们需要减去同一条路径走两边的情况。
需要求概率,我们就需要知道事件数是什么。对于这里来讲,我们需要考虑的是选择排列还是组合。由给出的样例我们看出,应该按照组合进行计算。因此对于每个路径和我们都应该/2
,这样得到的才是排列数。
需要注意的是每次FFT以后对数组的清零。我就是这里没有注意然后导致疯狂出错,还一直检查不出问题。
还有就是,树上的路径很多,很容易爆long long
AC代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<climits>
#include<algorithm>
#include<ctime>
#include<cstdlib>
#include<queue>
#include<set>
#include<map>
#include<cmath>
using namespace std;
const int MAXN=5e5+5;
const double PI=acos(-1.0);
typedef long long ll;
struct edge
{
int to,len,last;
}Edge[MAXN<<2]; int Last[MAXN],tot;
int n,kk,SonNum[MAXN],MaxNum[MAXN],Vis[MAXN],Dis[MAXN];
int Prime[MAXN]; bool IsPrime[MAXN]; int prime_num=0;
int root,rootx,dlen,ss;
ll Num[MAXN<<2],MaxLen,len;
ll Res[MAXN<<2];
ll ans;
struct complex
{
double r,i;
complex(double _r=0,double _i=0):r(_r),i(_i){}
complex operator +(const complex &b)
{
return complex(r+b.r,i+b.i);
}
complex operator -(const complex &b)
{
return complex(r-b.r,i-b.i);
}
complex operator *(const complex &b)
{
return complex(r*b.r-i*b.i,r*b.i+i*b.r);
}
}A[MAXN<<2];
void change(complex y[],int len)
{
int i,j,k;
for(i = 1, j = len/2;i < len-1;i++)
{
if(i < j)swap(y[i],y[j]);
k = len/2;
while( j >= k)
{
j -= k;
k /= 2;
}
if(j < k)j += k;
}
}
void fft(complex y[],int len,int on)
{
change(y,len);
for(int h = 2;h <= len;h <<= 1)
{
complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j = 0;j < len;j += h)
{
complex w(1,0);
for(int k = j;k < j+h/2;k++)
{
complex u = y[k];
complex t = w*y[k+h/2];
y[k] = u+t;
y[k+h/2] = u-t;
w = w*wn;
}
}
}
if(on == -1)
for(int i = 0;i < len;i++)
y[i].r /= len;
}
void FFT(ll a[],int la,ll b[])//la,lb分别是a,b数组的最高位+1
{
//int len=1; while(len<la+la) len<<=1;
for(int i=0;i<la;++i) A[i]=complex(a[i],0);
for(int i=la;i<len;++i) A[i]=complex(0,0);
fft(A,len,1);
for(int i=0;i<len;++i) A[i]=A[i]*A[i];
fft(A,len,-1);
for(int i=0;i<len;++i) b[i]=(ll)(A[i].r+0.5);
}
void CreatPrime()
{
IsPrime[0]=IsPrime[1]=true; prime_num=0;
for(int i=2;i<MAXN;++i)
{
if(!IsPrime[i])
Prime[prime_num++]=i;
for(int j=0;j<prime_num && Prime[j]*i<MAXN;j++)
{
IsPrime[Prime[j]*i]=true;
if(i%Prime[j]==0) break;
}
}
}
int getint()
{
int x=0,sign=1; char c=getchar();
while(c<'0' || c>'9')
{
if(c=='-') sign=-1; c=getchar();
}
while(c>='0' && c<='9')
{
x=x*10+c-'0'; c=getchar();
}
return x*sign;
}
void Init()
{
//CreatPrime();
//for(int i=0;i<=n;++i) Last[i]=0;
memset(Last,-1,sizeof(Last));
tot=0;
ans=0;
}
void Clear()
{
for(int i=0;i<=n;++i) Vis[i]=false;
}
void AddEdge(int u,int v,int w)
{
Edge[++tot].to=v; Edge[tot].len=w;
Edge[tot].last=Last[u]; Last[u]=tot;
}
void Read()
{
n=getint();
int u,v;
for(int i=1;i<n;i++)
{
u=getint(); v=getint();
AddEdge(u,v,1); AddEdge(v,u,1);
}
}
void GetRoot(int x,int father)
{
int v;
SonNum[x]=1; MaxNum[x]=1;
for(int i=Last[x];~i;i=Edge[i].last)
{
v=Edge[i].to; if(v==father || Vis[v]) continue;
GetRoot(v,x);
SonNum[x]+=SonNum[v];
if(SonNum[v]>MaxNum[x]) MaxNum[x]=SonNum[x];
}
if(ss-SonNum[x]>MaxNum[x]) MaxNum[x]=ss-SonNum[x];
if(rootx>MaxNum[x]) root=x,rootx=MaxNum[x];
}
void GetDis(int x,int father,int dis)
{
int v;
Dis[++dlen]=dis;
for(int i=Last[x];~i;i=Edge[i].last)
{
v=Edge[i].to; if(v==father|| Vis[v]) continue;
GetDis(v,x,dis+Edge[i].len);
}
}
ll Count(int x,int dis)
{
ll ret=0;
for(int i=0;i<=dlen;++i) Dis[i]=0;
dlen=0;
GetDis(x,0,dis);
/*
for(int i=1;i<=dlen;++i)
for(int j=i+1;j<=dlen;++j)
{
if(!IsPrime[Dis[i]+Dis[j]]) ++ret;
}
*/
//memset(Num,0,sizeof(Num));
MaxLen=0;
for(int i=1;i<=dlen;++i)
{
++Num[Dis[i]]; if(Dis[i]>MaxLen) MaxLen=Dis[i];
}
len=1; while(len<=2*MaxLen) len<<=1;
FFT(Num,MaxLen+1,Res);
for(int i=1;i<=dlen;++i)
{
--Res[Dis[i]+Dis[i]];
}
MaxLen<<=1;
for(int i=0;i<=len;++i)
{
Res[i]/=2;
}
for(int i=0;i<prime_num && Prime[i]<=MaxLen;++i)
{
ret+=Res[Prime[i]];
}
for(int i=1;i<=dlen;++i)
{
Num[Dis[i]]=0;
}
return ret;
}
void Solve(int x)
{
int v;
ans+=Count(x,0);
Vis[x]=true;
for(int i=Last[x];~i;i=Edge[i].last)
{
v=Edge[i].to; if(Vis[v]) continue;
ans-=Count(v,Edge[i].len);
ss=SonNum[v]; rootx=INT_MAX; root=0;
GetRoot(v,x);
Solve(root);
}
}
void Work()
{
rootx=INT_MAX; ss=n; root=0;
GetRoot(1,0);
Solve(root);
}
void Write()
{
ll tmp=(ll)n*(n-1)/2;
//printf("%.f\n",tmp);
printf("%.7f\n",(double)ans/tmp);
}
int main()
{
CreatPrime();
//while(~scanf("%d",&n))
//{
Init();
Read();
Work();
Write();
Clear();
//}
return 0;
}