题目链接:https://www.codechef.com/problems/PRIMEDST
题目大意:给出一棵 n 个点的树,从中等概率地任选两个不同的点,求它们之间的距离(边数)为素数的概率,即距离为素数的点对数除以总点对数 n(n-1)/2。
题目思路:要求树上两点之间的距离,很容易想到用树分治(点分治)来解决;但如果在分治时逐对枚举点,复杂度太高,所以借助FFT来降低复杂度。分治过程中,统计当前连通块内各点到重心的距离分布,用FFT对该分布做自卷积,得到任意两点经过重心的距离之和的分布;再减去两点位于同一棵子树内的情况(容斥),最后累加距离为素数的点对数即可。
具体实现看代码:
#include <bits/stdc++.h>
#define fi first
#define se second
// Segment-tree child ranges; only meaningful where l, r, m and rt are in scope.
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
// Lowest set bit of x. Both the parameter and the whole expansion are
// parenthesized so that lowbit(a+b) and 2*lowbit(x) expand as intended.
#define lowbit(x) ((x)&-(x))
#define pb push_back
#define MP make_pair
// Fill a whole array (a must be an array, not a pointer) with 0 / 0x3f bytes.
#define clr(a) memset(a,0,sizeof(a))
// NOTE(review): names starting with '_' + uppercase are reserved in C++; kept for compatibility.
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
// Debug print of a single value as "[x]".
#define fuck(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int>pii;
//head
constexpr int MX = 100005;        // capacity of the per-vertex arrays (1e5 + 5)
constexpr int inf = 0x3f3f3f3f;   // "infinity" used to seed the centroid search
int n;                            // number of vertices in the current test
int k;                            // NOTE(review): not referenced in the visible code
int MIN;                          // smallest max-part size seen so far by get_root
int cnt;                          // number of distances dfs has written into dis[]
int sz[MX];                       // subtree sizes filled in by get_sz
int dis[MX];                      // distances collected by dfs
int num[MX];                      // num[d] = number of collected vertices at distance d
bool vis[MX];                     // vertices already chosen as a centroid
// Forward-star adjacency list; main() stores every undirected tree edge once per direction.
struct edge {
    int v;     // vertex this directed edge leads to
    int nxt;   // next edge leaving the same vertex; -1 ends the list
};
edge E[MX << 1];   // two directed edges per undirected edge
int head[MX];      // head[u] = first edge leaving u, or -1
int tot;           // number of directed edges stored so far
int prime[MX], pcnt;   // primes below MX in increasing order, and their count
bool is_prime[MX];     // is_prime[x] is true iff x is prime (after prime_init)

// Linear (Euler) sieve over [2, MX): each composite is crossed out exactly once,
// by x * p where p is its smallest prime factor.
void prime_init(){
    for (int x = 2; x < MX; ++x) is_prime[x] = true;
    pcnt = 0;
    for (int x = 2; x < MX; ++x) {
        if (is_prime[x]) prime[pcnt++] = x;
        for (int j = 0; j < pcnt; ++j) {
            const int p = prime[j];
            if (x * p >= MX) break;
            is_prime[x * p] = false;
            if (x % p == 0) break;   // p divides x: larger primes would not be smallest factors
        }
    }
}
// Reset per-test graph state: empty adjacency lists and no centroid marked yet.
void init(){
    tot = 0;
    memset(head, -1, sizeof(head));
    memset(vis, 0, sizeof(vis));
}
void add_edge(int u,int v){
E[tot].v=v;E[tot].nxt=head[u];
head[u]=tot++;
}
const double pi = acos(-1.0);   // pi to double precision
int len;                        // FFT length used by work()/Conv (a power of two)
int mx;                         // largest distance seen in the current cal() call
ll res[MX << 2];                // rounded convolution output; sized 4x MX for headroom
// Minimal complex number used by the FFT routines.
struct Complex {
    double r, i;   // real and imaginary parts
    Complex(double r = 0, double i = 0) : r(r), i(i) {}
    // Operators are const-qualified so they also work on const operands.
    Complex operator+(const Complex &rhs) const { return Complex(r + rhs.r, i + rhs.i); }
    Complex operator-(const Complex &rhs) const { return Complex(r - rhs.r, i - rhs.i); }
    Complex operator*(const Complex &rhs) const { return Complex(r*rhs.r - i*rhs.i, i*rhs.r + r*rhs.i); }
} va[MX<<2], vb[MX<<2];   // FFT work buffers filled by cal() and consumed by work()
// Reorder F[0..len) into bit-reversed index order, the input order expected by the
// iterative FFT. len must be a power of two.
void rader(Complex F[],int len) {
    for (int i = 1, j = len >> 1; i < len - 1; ++i) {
        if (i < j) swap(F[i], F[j]);   // swap each pair only once
        // Turn j into the bit-reversal of i + 1: add one at the top bit and
        // propagate the carry toward the low bits.
        int bit = len >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        if (j < bit) j += bit;
    }
}
// In-place iterative radix-2 FFT over F[0..len); len must be a power of two.
//   t =  1 : forward transform (stage twiddle exp(-2*pi*i/h)).
//   t = -1 : inverse transform, including the 1/len scaling of both components.
void FFT(Complex F[],int len,int t) {
    rader(F,len);
    for(int h=2; h<=len; h<<=1) {
        const double ang = -t*2*pi/h;
        Complex wn(cos(ang),sin(ang));   // primitive h-th root of unity for this stage
        for(int j=0; j<len; j+=h) {
            Complex w(1,0);              // twiddle factor (named w so it does not shadow the edge array E)
            for(int k=j; k<j+h/2; ++k) {
                Complex u = F[k];
                Complex v = w*F[k+h/2];
                F[k] = u+v;
                F[k+h/2] = u-v;
                w=w*wn;
            }
        }
    }
    if(t==-1) { // IDFT: scale real AND imaginary parts so F is the true inverse transform
        for(int i=0; i<len; ++i) {
            F[i].r/=len;
            F[i].i/=len;
        }
    }
}
// Cyclic length-len convolution of a and b via the FFT (len a power of two).
// On return a[i].r holds the unrounded i-th coefficient. b is overwritten with
// its forward transform.
void Conv(Complex a[],Complex b[],int len) {
    FFT(a,len,1);
    FFT(b,len,1);
    for (Complex *pa = a, *pb = b; pa != a + len; ++pa, ++pb) *pa = *pa * *pb;
    FFT(a,len,-1);
}
void work() {
Conv(va,vb,len);
for(int i=0; i<len; ++i)res[i]=va[i].r + 0.5;
}
// Fill sz[x] with the size of x's subtree for every vertex reachable from u
// without entering an already-chosen centroid (vis). fa is u's parent; callers pass 0 for none.
void get_sz(int u,int fa){
    int total = 1;
    for (int e = head[u]; e != -1; e = E[e].nxt) {
        const int to = E[e].v;
        if (to != fa && !vis[to]) {
            get_sz(to, u);
            total += sz[to];
        }
    }
    sz[u] = total;
}
// Centroid search over the component of size num containing u, using sz[] from get_sz.
// For each vertex, the largest part left after removing it is either a child subtree
// or the rest of the component above it. The vertex minimizing that value is stored
// in rt, and the value itself in the global MIN. The strict < keeps the vertex found first on ties.
void get_root(int u,int fa,int num,int &rt){
    int largest = num - sz[u];   // the part of the component "above" u
    for (int e = head[u]; e != -1; e = E[e].nxt) {
        const int to = E[e].v;
        if (to == fa || vis[to]) continue;
        largest = max(largest, sz[to]);
        get_root(to, u, num, rt);
    }
    if (largest < MIN) {
        MIN = largest;
        rt = u;
    }
}
// Append to dis[] (advancing cnt) the distance of every vertex reachable from u
// without entering a vis vertex. u is at distance d and each edge adds 1.
void dfs(int u,int fa,int d){
    dis[cnt] = d;
    ++cnt;
    for (int e = head[u]; e != -1; e = E[e].nxt) {
        const int to = E[e].v;
        if (to != fa && !vis[to]) dfs(to, u, d + 1);
    }
}
// Count unordered pairs of distinct vertices {x, y} reachable from u (without entering
// vis vertices) for which dist(x) + dist(y) is prime. dist is measured as in dfs, with
// u at distance d. The pair counts come from an FFT self-convolution of the distance
// histogram. Every num[] entry incremented here is decremented again before returning.
ll cal(int u,int d){
    cnt = 0;
    dfs(u,0,d);

    // Distance histogram; mx = largest distance.
    mx = 0;
    for (int idx = 0; idx < cnt; ++idx) {
        const int di = dis[idx];
        ++num[di];
        if (di > mx) mx = di;
    }

    // Smallest power of two strictly above 2*mx, so the cyclic convolution is a linear one.
    for (len = 1; len <= 2 * mx; len <<= 1) {}

    for (int idx = 0; idx < len; ++idx) {
        const Complex c = (idx <= mx) ? Complex(num[idx], 0) : Complex(0, 0);
        va[idx] = c;
        vb[idx] = c;
    }
    work();   // res[s] = ordered pairs (x, y), x == y allowed, with distance sum s

    // Remove the x == y pairs, then halve to turn ordered pairs into unordered ones.
    for (int idx = 0; idx < cnt; ++idx) --res[2 * dis[idx]];
    for (int idx = 0; idx < len; ++idx) res[idx] /= 2;

    ll total = 0;
    for (int j = 0; j < pcnt && prime[j] <= 2 * mx; ++j) total += res[prime[j]];

    for (int idx = 0; idx < cnt; ++idx) --num[dis[idx]];   // undo this call's histogram
    return total;
}
ll ans;   // running number of vertex pairs at prime distance

// Centroid decomposition of the component containing u. For the centroid c:
// - add the pairs counted with distances measured through c (cal(c, 0));
// - subtract, for each child v, the pairs lying entirely inside v's subtree (cal(v, 1)),
//   because their real path does not pass through c;
// - recurse into that child's subtree.
void solve(int u){
    get_sz(u,0);
    int rt;
    MIN = inf;
    get_root(u,0,sz[u],rt);
    vis[rt] = true;
    ans += cal(rt,0);
    for (int e = head[rt]; e != -1; e = E[e].nxt) {
        const int child = E[e].v;
        if (!vis[child]) {
            ans -= cal(child,1);
            solve(child);
        }
    }
}
// Read trees until EOF. For each tree (vertices 1..n, n-1 edges), print the probability
// that two distinct vertices are at prime distance: prime-distance pairs / (n(n-1)/2).
int main(){
    //FIN;
    prime_init();
    while(~scanf("%d",&n)){
        init();
        for(int i = 1;i < n;i++){
            int u, v;
            scanf("%d%d",&u,&v);
            add_edge(u,v);
            add_edge(v,u);
        }
        ll all = (ll)n*(n-1)/2;
        ans = 0;
        // With fewer than two vertices there are no pairs; print 0 instead of 0/0 (NaN).
        if (all == 0) {
            printf("%.7f\n", 0.0);
            continue;
        }
        solve(1);
        printf("%.7f\n",1.0*ans/all);
    }
    return 0;
}