题目链接:点击打开链接
给定n c表示2棵树的节点数
下面n-1行给出第一棵树
下面c-1行给出第二棵树
问:
任选2个点,建一条边把这2棵树连起来后,整个图最大的链的长度的期望是多少。
思路:
首先给2棵树 X, Y各自求树的直径,
然后计算出X中每个点i 距离X上的最远点的距离(保存在X.dp[i]中)
同理计算出Y的每个点的最远距离。
那么当新建的这条边连着 X.i - Y.j 时,则最长链就是 max( X.dp[i] + Y.dp[j]+1, X.len, Y.len);
X.len 就是X树的直径。
然后我们对X.dp排个序找一下规律就好了。。
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<bits/stdc++.h>
template <class T>
inline bool rd(T &ret) {
char c; int sgn;
if(c=getchar(),c==EOF) return 0;
while(c!='-'&&(c<'0'||c>'9')) c=getchar();
sgn=(c=='-')?-1:1;
ret=(c=='-')?0:(c-'0');
while(c=getchar(),c>='0'&&c<='9') ret=ret*10+(c-'0');
ret*=sgn;
return 1;
}
template <class T>
inline void pt(T x) {
if (x <0) {
putchar('-');
x = -x;
}
if(x>9) pt(x/10);
putchar(x%10+'0');
}
using namespace std;
typedef long long ll;
const int N = 40010;
struct node{
struct Edge{
int to, nex;
}edge[N<<1];
int head[N], edgenum;
void init(){memset(head, -1, sizeof head); edgenum = 0;}
void add(int u, int v){
Edge E = {v, head[u]};
edge[edgenum] = E;
head[u] = edgenum++;
}
int dis[N], dp[N], pre[N], Stack[N], top, len;
bool vis[N];
int BFS(int x){
memset(vis,0,sizeof(vis));
queue<int>q;
q.push(x); vis[x] = 1;
dis[x]=0; pre[x] = -1;
int ans = x;
while(!q.empty()){
int u=q.front();q.pop();
for(int i = head[u]; ~i; i = edge[i].nex) {
int v = edge[i].to;
if(vis[v])continue;
dis[v]=dis[u]+1;
pre[v] = u;
if(dis[v] > dis[ans]) ans = v;
vis[v]=1;q.push(v);
}
}
return ans;
}
void dfs(int u){
vis[u] = 1;
for(int i = head[u];~i;i=edge[i].nex){
int v = edge[i].to;
if(vis[v]) continue;
dp[v] = dp[u]+1;
dfs(v);
}
}
void work(){
int E = BFS(1);
int S = BFS(E);
len = dis[S];
top = 0;
memset(vis, 0, sizeof vis);
int u = S;
while(u!=-1){
Stack[top++] = u;
vis[u] = 1;
u = pre[u];
}
for(int i = 0; i < top; i++){
int u = Stack[i];
dp[u] = max(dis[u], len - dis[u]);
dfs(u);
}
}
}X, Y;
int n, q;
int a[N], b[N];
ll sum[N];
void input(){
X.init(); Y.init();
for(int i = 1, u, v; i < n; i++){
rd(u); rd(v);
X.add(u,v); X.add(v,u);
}
for(int i = 1, u, v; i < q; i++){
rd(u); rd(v);
Y.add(u,v); Y.add(v,u);
}
}
int main(){
int L, idx;
ll tot;
while(cin>>n>>q){
input();
X.work(); Y.work();
for (int i = 1; i <= n; ++i)
a[i] = X.dp[i]+1;
for (int i = 1; i <= q; ++i)
b[i] = Y.dp[i];
L = max(X.len, Y.len);
sort(a+1, a+1+n);
sort(b+1, b+1+q);
tot = sum[0] = 0;
for (int i = 1; i <= q; ++i)
sum[i] = sum[i-1]+b[i];
idx = q+1;
for (int i = 1; i <= n; ++i) {
while (idx-1>=1 && b[idx-1]+a[i]>L)
-- idx;
if (idx==1) {
tot += sum[q] + (ll)q * a[i];
} else if (idx == q+1) {
tot += (ll)L * q;
} else {
tot += (ll)L * (idx-1);
tot += (ll)sum[q]-sum[idx-1] + (ll)(q-idx+1)*a[i];
}
}
printf("%.3f\n", (double)tot / n / q);
}
return 0;
}