https://codeforces.com/gym/102798/problem/C
比赛的时候枚举3个点的中心点然后嗯算方案数,写了100多行的dfs计数。。。然后wa 3
然而这题只要想到,3个点到中心点的最小距离,要么一条链在中间点上,要么在3叉路口上,这两种情况都可以拆成从a-b,b-c,a-c的路径之和/2
然后这个问题就转换成了3组独立问题,每组就是两个人任选2个点之间的期望距离之和
那么就直接dfs下去算每条边的贡献就行了,每条边的贡献就是连接的两端连通块中a,b的数量乘起来,a*b,b*a,a*c,c*a,b*c,c*b6种算一算就行了
md回去我要把我的嗯算给调出来
upd:
艹我嗯调出来了,就是枚举中心点,然后要么是3个分支上,要么是1个点在中心,另外两个在两个分支,要么是2个点在中心,另一个点在一个分支
计算的时候要存一下边长之和,还要存有多少种方案,然后算的时候还要容斥一下
#include<bits/stdc++.h>
using namespace std;
typedef __int128 ll;//__int128!!!
const int maxl=2e5+10;
int n,m[4],tot;
ll sumdis[4][4],num[4][maxl];
struct ed{int v,l;};
vector<ed> e[maxl];
bool in[4][maxl];
inline void prework()
{
scanf("%d",&n);
for(int i=1;i<=n-1;i++)
{
int u,v,l;
scanf("%d%d%d",&u,&v,&l);
e[u].push_back(ed{v,l});
e[v].push_back(ed{u,l});
}
for(int i=1;i<=3;i++)
{
scanf("%d",&m[i]);
for(int j=1;j<=m[i];j++)
{
int x;scanf("%d",&x);
in[i][x]=true;
}
}
}
inline void predfs(int u,int fa)
{
for(int i=1;i<=3;i++)
if(in[i][u])
num[i][u]=1;
for(ed ee:e[u])
if(ee.v!=fa)
{
int v=ee.v;
predfs(v,u);
for(int i=1;i<=3;i++)
for(int j=i+1;j<=3;j++)
sumdis[i][j]+=((m[i]-num[i][v])*num[j][v]+(m[j]-num[j][v])*num[i][v])*ee.l;
for(int i=1;i<=3;i++)
num[i][u]+=num[i][v];
}
}
inline void mainwork()
{
predfs(1,0);
}
inline void print()
{
double ans=0;
for(int i=1;i<=3;i++)
for(int j=i+1;j<=3;j++)
ans+=1.0*sumdis[i][j]/m[i]/m[j];
ans/=2;
printf("%.10f",ans);
}
int main()
{
prework();
mainwork();
print();
return 0;
}
#include<bits/stdc++.h>
using namespace std;
typedef __int128 ll;//__int128!!!
const int maxl=2e5+10;
int n,m[4],tot;
ll sumdis,o=1;
ll num[4][maxl],val[4][maxl];
ll sumnum[4][maxl],sum[4][maxl];
ll sum12[maxl],sum23[maxl],sum13[maxl];
ll num12[maxl],num23[maxl],num13[maxl];
struct ed{int v,l;};
vector<ed> e[maxl];
bool in[4][maxl];
inline void prework()
{
scanf("%d",&n);
for(int i=1;i<=n-1;i++)
{
int u,v,l;
scanf("%d%d%d",&u,&v,&l);
e[u].push_back(ed{v,l});
e[v].push_back(ed{u,l});
}
for(int i=1;i<=3;i++)
{
scanf("%d",&m[i]);
for(int j=1;j<=m[i];j++)
{
int x;scanf("%d",&x);
in[i][x]=true;
}
}
}
inline void predfs(int u,int fa,int fal)
{
for(int j=1;j<=3;j++)
if(in[j][u])
num[j][u]++;
for(ed ee:e[u])
if(ee.v!=fa)
{
predfs(ee.v,u,ee.l);
for(int j=1;j<=3;j++)
{
num[j][u]+=num[j][ee.v];
val[j][u]+=val[j][ee.v];
}
}
for(int j=1;j<=3;j++)
val[j][u]+=num[j][u]*fal;
}
inline void dfs(int u,int fa,ll fval1,ll fval2,ll fval3)
{
vector<ll> tmp[4],tmpnum[4];
tmp[1].push_back(fval1);tmp[2].push_back(fval2);tmp[3].push_back(fval3);
tmpnum[1].push_back(m[1]-num[1][u]);
tmpnum[2].push_back(m[2]-num[2][u]);
tmpnum[3].push_back(m[3]-num[3][u]);
for(ed ee:e[u])
if(ee.v!=fa)
for(int i=1;i<=3;i++)
{
tmp[i].push_back(val[i][ee.v]);
tmpnum[i].push_back(num[i][ee.v]);
}
for(int i=1;i<=3;i++)
{
sum[i][u]=0;sumnum[i][u]=0;
for(ll x:tmp[i])
sum[i][u]+=x;
for(ll x:tmpnum[i])
sumnum[i][u]+=x;
}
//12-...-3
if(in[1][u] && in[2][u])
sumdis+=sum[3][u];
if(in[1][u] && in[3][u])
sumdis+=sum[2][u];
if(in[2][u] && in[3][u])
sumdis+=sum[1][u];
int len=tmp[1].size();
sum12[u]=sum23[u]=sum13[u]=0;
num12[u]=num23[u]=num13[u]=0;
for(int j=0;j<len;j++)
sum12[u]+=tmp[1][j]*(sumnum[2][u]-tmpnum[2][j])+tmpnum[1][j]*(sum[2][u]-tmp[2][j]);
for(int j=0;j<len;j++)
{
sum23[u]+=tmp[2][j]*(sumnum[3][u]-tmpnum[3][j])+tmpnum[2][j]*(sum[3][u]-tmp[3][j]);
num23[u]+=tmpnum[2][j]*(sumnum[3][u]-tmpnum[3][j]);
}
for(int j=0;j<len;j++)
sum13[u]+=tmp[1][j]*(sumnum[3][u]-tmpnum[3][j])+tmpnum[1][j]*(sum[3][u]-tmp[3][j]);
// 1-..-2-..-3
if(in[1][u])
sumdis+=sum23[u];
if(in[2][u])
sumdis+=sum13[u];
if(in[3][u])
sumdis+=sum12[u];
// -3
// 1-z-2
for(int j=0;j<len;j++)
{
ll x=tmp[1][j]*(num23[u]-tmpnum[2][j]*(sumnum[3][u]-tmpnum[3][j])-tmpnum[3][j]*(sumnum[2][u]-tmpnum[2][j]));
x+=tmpnum[1][j]*(sum23[u]-(tmp[2][j]*(sumnum[3][u]-tmpnum[3][j])+tmpnum[2][j]*(sum[3][u]-tmp[3][j]))-(tmp[3][j]*(sumnum[2][u]-tmpnum[2][j])+tmpnum[3][j]*(sum[2][u]-tmp[2][j])));
sumdis+=x;
}
for(int i=1;i<=3;i++)
tmp[i].clear(),tmpnum[i].clear();
for(ed ee:e[u])
if(ee.v!=fa)
{
int v=ee.v;
ll nf1=sum[1][u]-val[1][v]+(m[1]-num[1][v])*ee.l;
ll nf2=sum[2][u]-val[2][v]+(m[2]-num[2][v])*ee.l;
ll nf3=sum[3][u]-val[3][v]+(m[3]-num[3][v])*ee.l;
dfs(v,u,nf1,nf2,nf3);
}
}
inline void mainwork()
{
predfs(1,0,0);
sumdis=0;
dfs(1,0,0,0,0);
}
inline void print()
{
double ans=(long double)sumdis/m[1]/m[2]/m[3];
printf("%.10f",ans);
}
int main()
{
prework();
mainwork();
print();
return 0;
}