传送门
分析
首先需要推出一个结论(树上三点到同一点的距离和的最小值):

$$\min_{v}\big(dis(u_1,v)+dis(u_2,v)+dis(u_3,v)\big)=\frac{1}{2}\big(dis(u_1,u_2)+dis(u_2,u_3)+dis(u_3,u_1)\big)$$
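简单证明一下:树上任意两点之间的路径唯一。取 $v$ 为 $u_1,u_2,u_3$ 三条两两路径的公共汇合点,则 $u_1$ 到 $u_2$ 和 $u_1$ 到 $u_3$ 的路径都会经过 $u_1v$ 这一段,所以在 $dis(u_1,u_2)+dis(u_2,u_3)+dis(u_3,u_1)$ 中 $dis(u_1,v)$ 恰好被计算 $2$ 次;同理,$dis(u_2,v)$ 和 $dis(u_3,v)$ 也各被计算 $2$ 次。因此对这个 $v$ 有 $dis(u_1,u_2)+dis(u_2,u_3)+dis(u_3,u_1)=2\big(dis(u_1,v)+dis(u_2,v)+dis(u_3,v)\big)$。另一方面,对任意点 $v$,由三角不等式 $dis(u_i,v)+dis(u_j,v)\ge dis(u_i,u_j)$,三式相加可知距离和不会小于上式右边的一半,所以这就是最小值。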
这样,这个问题就转化成了求

$$E\Big(\frac{1}{2}\big(dis(u_1,u_2)+dis(u_2,u_3)+dis(u_3,u_1)\big)\Big)=\frac{1}{2}\Big(E\big(dis(u_1,u_2)\big)+E\big(dis(u_2,u_3)\big)+E\big(dis(u_3,u_1)\big)\Big)$$

(期望的线性性)。以第一项为例,设 $S_1,S_2$ 为两组点集,大小分别为 $num_1,num_2$,则

$$E\big(dis(u_1,u_2)\big)=\frac{\sum_{x\in S_1}\sum_{y\in S_2} dis(x,y)}{num_1\cdot num_2}$$
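剩下的就是算每条边的贡献:对于边 $(u,j)$($j$ 为子节点,边权为 $w$),所有两端分别位于这条边两侧的点对都会经过它,贡献为 $w\cdot\big(cnt_1(j)\,(num_2-cnt_2(j))+cnt_2(j)\,(num_1-cnt_1(j))\big)$,其中 $cnt_k(j)$ 表示 $j$ 的子树中第 $k$ 组点的个数。用树形 DP(一遍 DFS 统计子树大小)就可以啦。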
代码
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define debug(x) cout<<#x<<":"<<x<<endl;
#define dl(x) printf("%lld\n",x);
#define di(x) printf("%d\n",x);
#define _CRT_SECURE_NO_WARNINGS
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
using namespace std;
// Short aliases for the integer and container types used throughout.
using ll = long long;
using ull = unsigned long long;
using PII = pair<int, int>;
using VI = vector<int>;
// INF, mod, eps and PI are template constants; none is referenced by the code shown.
const int INF = 0x3f3f3f3f;
// N bounds the vertex ids; M = 2 * N edge slots because every undirected
// edge is stored once in each direction.
const int N = 200000 + 10, M = 2 * N;
const ll mod = 1000000007;
const double eps = 1e-9;
const double PI = acos(-1);
// Fast reader: skips every non-digit character before the number (a '-' seen
// while skipping makes the result negative), then parses the decimal digits
// into `a`. If EOF is reached before any digit, `a` is set to 0 instead of
// looping forever.
template<typename T>inline void read(T &a) {
    // int, not char: getchar() returns an unsigned-char value or EOF, and
    // isdigit() is undefined for negative values other than EOF.
    int c = getchar();
    T x = 0, f = 1;
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (isdigit(c)) {  // isdigit(EOF) is false, so EOF ends the number
        x = x * 10 + (c - '0');
        c = getchar();
    }
    a = f * x;
}
// Euclid's algorithm in loop form: repeatedly replaces (a, b) with (b, a % b)
// while b is positive, then returns a.
int gcd(int a, int b) {
    while (b > 0) {
        const int r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Adjacency lists stored as linked edge arrays: h[x] is the index of the first
// edge leaving x (-1 means none; main() memsets h to -1), and for edge i,
// e[i] is its target, w[i] its weight, ne[i] the next edge from the same source.
// idx is the next unused edge slot.
int h[N],e[M],ne[M],w[M],idx;
// num[p]: number of vertices read for group p (only p = 0..2 are used).
ll num[10];
// sz[p][x]: count of group-p vertices at x; after cal(), the count in x's subtree.
ll sz[10][N];
// d[p]: weighted distance sum over all pairs (x in group p, y in group (p+1)%3),
// accumulated edge by edge in dfs().
ll d[10];
// n: number of vertices (n - 1 edges are read).
int n;
void add(int x,int y,int z){
ne[idx] = h[x],e[idx] = y,w[idx] = z,h[x] = idx++;
}
// Turns the per-vertex group counts sz[p][x] (p = 0..2) into subtree counts:
// afterwards sz[p][x] is the number of group-p entries in the subtree of x,
// for every x in the tree rooted at u (fa is u's parent, e.g. 0 for the root).
// Iterative on purpose: the recursive form needs one call frame per tree level,
// and a path-shaped tree of ~2e5 vertices can exhaust the call stack.
void cal(int u,int fa){
    // (vertex, parent) pairs in DFS preorder; every vertex appears after its parent.
    std::vector<std::pair<int,int> > order;
    std::vector<std::pair<int,int> > stk;
    stk.push_back(std::make_pair(u, fa));
    while(!stk.empty()){
        std::pair<int,int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        for(int i = h[cur.first];~i;i = ne[i]){
            if(e[i] != cur.second) stk.push_back(std::make_pair(e[i], cur.first));
        }
    }
    // Walk the preorder backwards so each vertex's subtree totals are complete
    // before they are added into its parent. Index 0 is u itself, whose totals
    // must not be pushed into fa.
    for(std::size_t k = order.size(); k-- > 1;){
        int v = order[k].first, par = order[k].second;
        for(int p = 0;p < 3;p++) sz[p][par] += sz[p][v];
    }
}
// Adds into d[p] (p = 0..2) the weighted distance sum over all pairs
// (x in group p, y in group (p+1)%3). Each edge contributes its weight times
// the number of such pairs whose endpoints lie on opposite sides of it.
// Needs the subtree counts produced by cal() on the same root. Iterative for
// the same stack-depth reason as cal().
void dfs(int u,int fa){
    std::vector<std::pair<int,int> > stk;
    stk.push_back(std::make_pair(u, fa));
    while(!stk.empty()){
        int v = stk.back().first, par = stk.back().second;
        stk.pop_back();
        for(int i = h[v];~i;i = ne[i]){
            int j = e[i];
            if(j == par) continue;
            // Edge v-j separates j's subtree from the rest of the tree.
            for(int p = 0;p < 3;p++){
                int q = (p + 1) % 3;
                d[p] += w[i] * sz[p][j] * (num[q] - sz[q][j]) + w[i] * sz[q][j] * (num[p] - sz[p][j]);
            }
            stk.push_back(std::make_pair(j, v));
        }
    }
}
int main() {
memset(h,-1,sizeof h);
read(n);
for(int i = 1;i < n;i++) {
int a,b,c;
read(a),read(b),read(c);
add(a,b,c),add(b,a,c);
}
for(int i = 0;i < 3;i++){
read(num[i]);
for(int j = 1;j <= num[i];j++){
int x;
read(x);
sz[i][x]++;
}
}
cal(1,0);
dfs(1,0);
double ans = 0;
ans = 1.0 * d[0] / (num[0] * num[1]);
ans += 1.0 * d[1] / (num[1] * num[2]);
ans += 1.0 * d[2] / (num[2] * num[0]);
ans /= 2;
printf("%.9lf\n",ans);
return 0;
}