题意
有两棵树,现在会在两棵树上各选一点相连,问相连后树上最长距离的期望是多少
思路
先进行三遍dfs求出两棵树上每个点在它所在树上的最远距离,如果选择两个点进行合并,那么这个新的树上最长距离就是原来的树上最长距离和这两个点在原树上最远距离之和加1中的较大值,所以我们可以先对两个树上的最远距离排序,然后二分得到固定一个点,有多少种连接会是前者更大,多少种会是后者更大,然后进行计算,最后除以总情况数就好
代码
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;
long long depth1[40001],depth2[40001];
vector<long long> mp[40001];
long long add[40001];
long long dfs(long long x,long long from,long long step,long long *d)
{
long long dep=-1;
for(long long i=0;i<mp[x].size();i++)
if(mp[x][i]!=from)
dep=max(dep,dfs(mp[x][i],x,step+1,d));
d[x]=max(d[x],max(step,dep+1));
return dep+1;
}
int main()
{
long long N,Q,x,y,t,maxt,p,pp,tp;
long long ans,maxx;
while(scanf("%lld%lld",&N,&Q)!=EOF)
{
for(long long i=0;i<N-1;i++)
{
scanf("%lld%lld",&x,&y);
mp[x].push_back(y);
mp[y].push_back(x);
}
t=0;
for(long long i=1;i<=N;i++)
if(mp[i].size()==1)
{
t=i;
break;
}
dfs(t,-1,0,depth1);
maxt=-1;
p=-1;
for(long long i=1;i<=N;i++)
if(mp[i].size()==1&&i!=t&&depth1[i]>maxt)
{
maxt=depth1[i];
p=i;
}
if(p!=-1)
dfs(p,-1,0,depth1);
tp=-1;
maxt=-1;
for(long long i=1;i<=N;i++)
if(mp[i].size()==1&&i!=t&&i!=p&&depth1[i]>maxt)
{
maxt=depth1[i];
tp=i;
}
if(tp!=-1)
dfs(tp,-1,0,depth1);
for(long long i=1;i<=N;i++)
mp[i].clear();
for(long long i=0;i<Q-1;i++)
{
scanf("%lld%lld",&x,&y);
mp[x].push_back(y);
mp[y].push_back(x);
}
t=0;
for(long long i=1;i<=Q;i++)
if(mp[i].size()==1)
{
t=i;
break;
}
dfs(t,-1,0,depth2);
maxt=-1;
p=-1;
for(long long i=1;i<=Q;i++)
if(mp[i].size()==1&&i!=t&&depth2[i]>maxt)
{
maxt=depth2[i];
p=i;
}
if(p!=-1)
dfs(p,-1,0,depth2);
tp=-1;
maxt=-1;
for(long long i=1;i<=Q;i++)
if(mp[i].size()==1&&i!=t&&i!=p&&depth2[i]>maxt)
{
maxt=depth2[i];
tp=i;
}
if(tp!=-1)
dfs(tp,-1,0,depth2);
for(long long i=1;i<=Q;i++)
mp[i].clear();
ans=0;
sort(depth1,depth1+N+1);
sort(depth2,depth2+Q+1);
maxx=max(depth1[N],depth2[Q]);
for(long long i=1;i<=Q;i++)
add[i]=add[i-1]+depth2[i];
for(long long i=1;i<=N;i++)
{
pp=upper_bound(depth2+1,depth2+Q+1,maxx-depth1[i]-1)-depth2-1;
ans+=maxx*pp;
ans+=(add[Q]-add[pp])+(depth1[i]+1)*(Q-pp);
}
printf("%.3f\n",double(ans)/(N*Q));
for(long long i=1;i<=N;i++)
depth1[i]=0;
for(long long i=1;i<=Q;i++)
depth2[i]=0;
}
return 0;
}