题意:给你一棵节点数为n的树,随机地在树上的任意两个点连一条边,给你m个询问,每次询问两个点,问连一条边后如果这两个点能在简单环中,简单环的期望是多少。
简单环即这两个点在一个环上,这个环是没有重边的。
思路:这里先设置几个变量dep[i]:i节点的深度,这里记dep[0]=0,dep[1]=1;sz[i]:i节点的子树的节点总数;f[i][j]:i节点的2^j倍父亲;sdown[i]:i节点子树中的所有点到i节点的距离和;sall[i]:所有点到i节点的距离和;t=lca(u,v).
先考虑lca(u,v)!=u && lca(u,v)!=v的情况,想要使得u,v都在简单环中,那么连边的两个端点一定是一个在u的子树中,另一个在v的子树中,且连边的方案数为sz[u]*sz[v],那么我们得到的期望值是sdown[u]/sz[u]+sdown[v]/sz[v]+1+dep[u]+dep[v]-2*dep[t].这里dep[u]+dep[v]-2*dep[t]+1是每一个形成的简单环都有的长度,所以可以先加上去.
然后考虑lca(u,v)==u || lca(u,v)==v的情况,不妨假设lca(u,v)=v,那么连边的两个端点一端一定在u的子树中,另一端在v的上面,即树上的所有点除去不包括u这个节点的子树,我们得到的期望值是sdown[u]/sz[u]+(sall[v]-sdown[v1]-sz[v1])/(n-sz[v1]) (v1是u,v路径上v的子节点).
第一次dfs先求出sdown[i],然后第二次dfs就能求出sall[i]了.
- #include<iostream>
- #include<stdio.h>
- #include<stdlib.h>
- #include<string.h>
- #include<math.h>
- #include<vector>
- #include<map>
- #include<set>
- #include<queue>
- #include<stack>
- #include<string>
- #include<algorithm>
- using namespace std;
- typedef long long ll;
- typedef long double ldb;
- #define inf 99999999
- #define pi acos(-1.0)
- #define maxn 100050
-
- int sz[maxn],dep[maxn],f[maxn][23];
- ll sdown[maxn],sall[maxn];
- int n;
- struct edge{
- int to,next;
- }e[2*maxn];
- int first[maxn];
- void dfs1(int u,int father,int deep)
- {
- int i,j,v;
- dep[u]=dep[father]+1;
- sz[u]=1;sdown[u]=0;
- for(i=first[u];i!=-1;i=e[i].next){
- v=e[i].to;
- if(v==father)continue;
- f[v][0]=u;
- dfs1(v,u,dep[u]);
- sz[u]+=sz[v];
- sdown[u]+=sdown[v]+sz[v];
- }
- }
-
- void dfs2(int u,int father)
- {
- int i,j,v;
- for(i=first[u];i!=-1;i=e[i].next){
- v=e[i].to;
- if(v==father)continue;
- sall[v]=sall[u]+n-2*sz[v];
- dfs2(v,u);
- }
- }
- void init()
- {
- dep[0]=0;
- dfs1(1,0,0);
- sall[1]=sdown[1];
- dfs2(1,0);
- }
- int lca(int x,int y){
- int i;
- if(dep[x]<dep[y]){
- swap(x,y);
- }
- for(i=20;i>=0;i--){
- if(dep[f[x][i] ]>=dep[y]){
- x=f[x][i];
- }
- }
- if(x==y)return x;
- for(i=20;i>=0;i--){
- if(f[x][i]!=f[y][i]){
- x=f[x][i];y=f[y][i];
- }
- }
- return f[x][0];
- }
- int up(int u,int deep)
- {
- int i,j;
- for(i=20;i>=0;i--){
- if((1<<i)<=deep){
- u=f[u][i];
- deep-=(1<<i);
- }
- }
- return u;
-
- }
- int main()
- {
- int m,i,j,tot,c,d,v,u,k;
- double sum;
- while(scanf("%d%d",&n,&m)!=EOF)
- {
- tot=0;
- memset(first,-1,sizeof(first));
- for(i=1;i<=n-1;i++){
- scanf("%d%d",&c,&d);
- tot++;
- e[tot].next=first[c];e[tot].to=d;
- first[c]=tot;
-
- tot++;
- e[tot].next=first[d];e[tot].to=c;
- first[d]=tot;
- }
- init();
- for(k=1;k<=20;k++){
- for(i=1;i<=n;i++){
- f[i][k]=f[f[i][k-1]][k-1];
- }
- }
- for(i=1;i<=m;i++){
- scanf("%d%d",&u,&v);
- int t=lca(u,v);
- sum=(double)(dep[u]+dep[v]-2*dep[t])+1;
- if(t==u || t==v){
- if(t==u)swap(u,v);
- int v1=up(u,dep[u]-dep[v]-1);
- ll num1=sall[v]-sdown[v1]-sz[v1];
- sum+=(double)sdown[u]/(double)sz[u]+(double)(num1)/(double)(n-sz[v1]);
- printf("%.10f\n",sum);
- }
- else{
- sum+=(double)sdown[u]/(double)sz[u]+(double)sdown[v]/(double)sz[v];
- printf("%.10f\n",sum);
- }
- }
- }
- return 0;
- }